Skip to content

Commit 688551f

Browse files
authored
[mlir][spirv] Fix serialization of TensorARM with rank higher than one (#152391)
This PR fixes #152012 where serialization of TensorARM values into OpConstantComposite resulted in invalid binary. --------- Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent 4f2ed92 commit 688551f

File tree

3 files changed

+70
-33
lines changed

3 files changed

+70
-33
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
15601560
}
15611561

15621562
auto resultID = operands[1];
1563-
if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1563+
if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1564+
SmallVector<Attribute> flattenedElems;
1565+
for (Attribute element : elements) {
1566+
if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1567+
for (auto value : denseElemAttr.getValues<Attribute>())
1568+
flattenedElems.push_back(value);
1569+
} else {
1570+
flattenedElems.push_back(element);
1571+
}
1572+
}
1573+
auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
1574+
constantMap.try_emplace(resultID, attr, tensorType);
1575+
} else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
15641576
auto attr = DenseElementsAttr::get(shapedType, elements);
15651577
// For normal constants, we just record the attribute (and its type) for
15661578
// later materialization at use sites.

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
956956
uint32_t resultID = getNextID();
957957
SmallVector<uint32_t, 4> operands = {typeID, resultID};
958958
auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
959+
if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
960+
ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
961+
if (!innerShape.empty())
962+
elementType = spirv::TensorArmType::get(innerShape, elementType);
963+
}
959964

960965
// "If the Result Type is a cooperative matrix type, then there must be only
961966
// one Constituent, with scalar type matching the cooperative matrix Component
@@ -979,30 +984,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
979984
} else {
980985
return 0;
981986
}
982-
} else if (isa<spirv::TensorArmType>(constType)) {
983-
if (isZeroValue(valueAttr)) {
984-
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
985-
{typeID, resultID});
986-
return resultID;
987-
}
988-
numberOfConstituents = shapedType.getNumElements();
989-
operands.reserve(numberOfConstituents + 2);
990-
for (int i = 0; i < numberOfConstituents; ++i) {
991-
uint32_t elementID = 0;
992-
if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
993-
elementID =
994-
elementType.isInteger(1)
995-
? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
996-
: prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
997-
}
998-
if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
999-
elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
1000-
}
1001-
if (!elementID) {
1002-
return 0;
1003-
}
1004-
operands.push_back(elementID);
1005-
}
987+
} else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
988+
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
989+
{typeID, resultID});
990+
return resultID;
1006991
} else {
1007992
operands.reserve(numberOfConstituents + 2);
1008993
for (int i = 0; i < numberOfConstituents; ++i) {

mlir/test/Target/SPIRV/arm-tensor-constant.mlir

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,71 @@
11
// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
2-
// DISABLED: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
3-
4-
// FIXME(#152012): Fix arm tensor constant validation errors and reenable spirv-val tests.
2+
// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
53

64
spirv.module Logical Vulkan requires #spirv.vce<v1.3,
75
[VulkanMemoryModel, Shader, TensorsARM, Linkage], [SPV_KHR_vulkan_memory_model, SPV_ARM_tensors]> {
8-
// CHECK-LABEL: @arm_tensor_of_i32
9-
spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
6+
// CHECK-LABEL: @rank_1_arm_tensor_of_i32
7+
spirv.func @rank_1_arm_tensor_of_i32() -> (!spirv.arm.tensor<3xi32>) "None" {
8+
// CHECK: {{%.*}} = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
9+
%0 = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
10+
spirv.ReturnValue %0 : !spirv.arm.tensor<3xi32>
11+
}
12+
13+
// CHECK-LABEL: @rank_2_arm_tensor_of_i32
14+
spirv.func @rank_2_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
1015
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
1116
%0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
1217
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
1318
}
1419

20+
// CHECK-LABEL: @rank_3_arm_tensor_of_i32
21+
spirv.func @rank_3_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x2x3xi32>) "None" {
22+
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1, 2, 3], [4, 5, 6]], {{\[}}[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
23+
%0 = spirv.Constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
24+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xi32>
25+
}
26+
27+
// CHECK-LABEL: @rank_4_arm_tensor_of_i32
28+
spirv.func @rank_4_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3x4x5xi32>) "None" {
29+
// CHECK: {{%.*}} = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
30+
%0 = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
31+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xi32>
32+
}
33+
1534
// CHECK-LABEL: @splat_arm_tensor_of_i32
1635
spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
1736
// CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
1837
%0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
1938
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
2039
}
2140

22-
// CHECK-LABEL: @arm_tensor_of_f32
23-
spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
41+
// CHECK-LABEL: @rank_1_arm_tensor_of_f32
42+
spirv.func @rank_1_arm_tensor_of_f32() -> (!spirv.arm.tensor<3xf32>) "None" {
43+
// CHECK: {{%.*}} = spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : !spirv.arm.tensor<3xf32>
44+
%0 = spirv.Constant dense<[1.0, 2.0, 3.0]> : !spirv.arm.tensor<3xf32>
45+
spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
46+
}
47+
48+
// CHECK-LABEL: @rank_2_arm_tensor_of_f32
49+
spirv.func @rank_2_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
2450
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
25-
%0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
51+
%0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : !spirv.arm.tensor<2x3xf32>
2652
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
2753
}
2854

55+
// CHECK-LABEL: @rank_3_arm_tensor_of_f32
56+
spirv.func @rank_3_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x2x3xf32>) "None" {
57+
// CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]], {{\[}}[7.000000e+00, 8.000000e+00, 9.000000e+00], [1.000000e+01, 1.100000e+01, 1.200000e+01]]]> : !spirv.arm.tensor<2x2x3xf32>
58+
%0 = spirv.Constant dense<[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]> : !spirv.arm.tensor<2x2x3xf32>
59+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xf32>
60+
}
61+
62+
// CHECK-LABEL: @rank_4_arm_tensor_of_f32
63+
spirv.func @rank_4_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3x4x5xf32>) "None" {
64+
// CHECK: {{%.*}} = spirv.Constant dense<5.000000e+00> : !spirv.arm.tensor<2x3x4x5xf32>
65+
%0 = spirv.Constant dense<5.0> : !spirv.arm.tensor<2x3x4x5xf32>
66+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xf32>
67+
}
68+
2969
// CHECK-LABEL: @splat_arm_tensor_of_f32
3070
spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
3171
// CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>

0 commit comments

Comments
 (0)