-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][spirv] Fix serialization of TensorARM with rank higher than one #152391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Fix serialization of TensorARM with rank higher than one #152391
Conversation
This addresses issue llvm#152012 where serialization of TensorARM values into OpConstantComposite resulted in invalid binary. Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Mohammadreza Ameri Mahabadian (mahabadm) ChangesThis addresses issue #152012 where serialization of TensorARM values into OpConstantComposite resulted in invalid binary. Full diff: https://github.com/llvm/llvm-project/pull/152391.diff 3 Files Affected:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index c967e863554fc..d8c54ec5f88c3 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
- if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+ if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+ SmallVector<Attribute> flattenedElems;
+ for (Attribute element : elements) {
+ if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
+ for (auto value : denseElemAttr.getValues<Attribute>())
+ flattenedElems.push_back(value);
+ } else {
+ flattenedElems.push_back(element);
+ }
+ }
+ auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
+ constantMap.try_emplace(resultID, attr, tensorType);
+ } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
auto attr = DenseElementsAttr::get(shapedType, elements);
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c049574fbc9e3..04277be1a192d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -956,6 +956,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
+ if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
+ ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
+ if (innerShape.size() > 0)
+ elementType = spirv::TensorArmType::get(innerShape, elementType);
+ }
// "If the Result Type is a cooperative matrix type, then there must be only
// one Constituent, with scalar type matching the cooperative matrix Component
@@ -979,30 +984,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
} else {
return 0;
}
- } else if (isa<spirv::TensorArmType>(constType)) {
- if (isZeroValue(valueAttr)) {
- encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
- {typeID, resultID});
- return resultID;
- }
- numberOfConstituents = shapedType.getNumElements();
- operands.reserve(numberOfConstituents + 2);
- for (int i = 0; i < numberOfConstituents; ++i) {
- uint32_t elementID = 0;
- if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
- elementID =
- elementType.isInteger(1)
- ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
- : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
- }
- if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
- elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
- }
- if (!elementID) {
- return 0;
- }
- operands.push_back(elementID);
- }
+ } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ return resultID;
} else {
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
diff --git a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
index 275e586f70634..7fb8af1904388 100644
--- a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
+++ b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
@@ -1,17 +1,36 @@
// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
-// DISABLED: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
-
-// FIXME(#152012): Fix arm tensor constant validation errors and reenable spirv-val tests.
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
spirv.module Logical Vulkan requires #spirv.vce<v1.3,
[VulkanMemoryModel, Shader, TensorsARM, Linkage], [SPV_KHR_vulkan_memory_model, SPV_ARM_tensors]> {
- // CHECK-LABEL: @arm_tensor_of_i32
- spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK-LABEL: @rank_1_arm_tensor_of_i32
+ spirv.func @rank_1_arm_tensor_of_i32() -> (!spirv.arm.tensor<3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+ %0 = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xi32>
+ }
+
+ // CHECK-LABEL: @rank_2_arm_tensor_of_i32
+ spirv.func @rank_2_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
%0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
+ // CHECK-LABEL: @rank_3_arm_tensor_of_i32
+ spirv.func @rank_3_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x2x3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1, 2, 3], [4, 5, 6]], {{\[}}[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+ %0 = spirv.Constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xi32>
+ }
+
+ // CHECK-LABEL: @rank_4_arm_tensor_of_i32
+ spirv.func @rank_4_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3x4x5xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+ %0 = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xi32>
+ }
+
// CHECK-LABEL: @splat_arm_tensor_of_i32
spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
@@ -19,13 +38,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3,
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
- // CHECK-LABEL: @arm_tensor_of_f32
- spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK-LABEL: @rank_1_arm_tensor_of_f32
+ spirv.func @rank_1_arm_tensor_of_f32() -> (!spirv.arm.tensor<3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : !spirv.arm.tensor<3xf32>
+ %0 = spirv.Constant dense<[1.0, 2.0, 3.0]> : !spirv.arm.tensor<3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+ }
+
+ // CHECK-LABEL: @rank_2_arm_tensor_of_f32
+ spirv.func @rank_2_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
- %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : !spirv.arm.tensor<2x3xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+ // CHECK-LABEL: @rank_3_arm_tensor_of_f32
+ spirv.func @rank_3_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x2x3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]], {{\[}}[7.000000e+00, 8.000000e+00, 9.000000e+00], [1.000000e+01, 1.100000e+01, 1.200000e+01]]]> : !spirv.arm.tensor<2x2x3xf32>
+ %0 = spirv.Constant dense<[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]> : !spirv.arm.tensor<2x2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xf32>
+ }
+
+ // CHECK-LABEL: @rank_4_arm_tensor_of_f32
+ spirv.func @rank_4_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3x4x5xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<5.000000e+00> : !spirv.arm.tensor<2x3x4x5xf32>
+ %0 = spirv.Constant dense<5.0> : !spirv.arm.tensor<2x3x4x5xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xf32>
+ }
+
// CHECK-LABEL: @splat_arm_tensor_of_f32
spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing this. Could you rebase your PR? With #152124, we should be able to run spirv-val in the CI and make sure it doesn't complain.
@kuhar Thanks for your note. I have rebased and seems like that the test have passed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@kuhar Would you please kindly merge this patch, if there are no further comments? Many thanks. |
This PR fixes #152012 where serialization of TensorARM values into OpConstantComposite resulted in invalid binary.