Skip to content

Commit c1968fe

Browse files
authored
[mlir][spirv] Fix serialization of multi-dimensional TensorArm constant (#151158)
This fixes an issue where multi-dimensional TensorArm dense elements could not be serialized. Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent c2c8644 commit c1968fe

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,25 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
958958
} else {
959959
return 0;
960960
}
961+
} else if (isa<spirv::TensorArmType>(constType)) {
962+
numberOfConstituents = shapedType.getNumElements();
963+
operands.reserve(numberOfConstituents + 2);
964+
for (int i = 0; i < numberOfConstituents; ++i) {
965+
uint32_t elementID = 0;
966+
if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
967+
elementID =
968+
elementType.isInteger(1)
969+
? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
970+
: prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
971+
}
972+
if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
973+
elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
974+
}
975+
if (!elementID) {
976+
return 0;
977+
}
978+
operands.push_back(elementID);
979+
}
961980
} else {
962981
operands.reserve(numberOfConstituents + 2);
963982
for (int i = 0; i < numberOfConstituents; ++i) {

mlir/test/Target/SPIRV/constant.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
307307
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
308308
}
309309

310+
// CHECK-LABEL: @arm_tensor_of_i32
311+
spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
312+
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
313+
%0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
314+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
315+
}
316+
317+
// CHECK-LABEL: @splat_arm_tensor_of_i32
318+
spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
319+
// CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
320+
%0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
321+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
322+
}
323+
324+
// CHECK-LABEL: @arm_tensor_of_f32
325+
spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
326+
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
327+
%0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
328+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
329+
}
330+
331+
// CHECK-LABEL: @splat_arm_tensor_of_f32
332+
spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
333+
// CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>
334+
%0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
335+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
336+
}
337+
310338
spirv.EntryPoint "GLCompute" @bool_const
311339
}
312340

0 commit comments

Comments
 (0)