Skip to content

[mlir][spirv] Fix serialization of TensorARM with rank higher than one #152391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 8, 2025

Conversation

mahabadm
Copy link
Contributor

@mahabadm mahabadm commented Aug 6, 2025

This PR fixes #152012 where serialization of TensorARM values into OpConstantComposite resulted in invalid binary.

This addresses issue llvm#152012 where serialization of TensorARM values into OpConstantComposite resulted in invalid binary.

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Aug 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Mohammadreza Ameri Mahabadian (mahabadm)

Changes

This addresses issue #152012 where serialization of TensorARM values into OpConstantComposite resulted in invalid binary.


Full diff: https://github.com/llvm/llvm-project/pull/152391.diff

3 Files Affected:

  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+13-1)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+9-24)
  • (modified) mlir/test/Target/SPIRV/arm-tensor-constant.mlir (+48-8)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index c967e863554fc..d8c54ec5f88c3 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
-  if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+  if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+    SmallVector<Attribute> flattenedElems;
+    for (Attribute element : elements) {
+      if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
+        for (auto value : denseElemAttr.getValues<Attribute>())
+          flattenedElems.push_back(value);
+      } else {
+        flattenedElems.push_back(element);
+      }
+    }
+    auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
+    constantMap.try_emplace(resultID, attr, tensorType);
+  } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
     auto attr = DenseElementsAttr::get(shapedType, elements);
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c049574fbc9e3..04277be1a192d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -956,6 +956,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
   uint32_t resultID = getNextID();
   SmallVector<uint32_t, 4> operands = {typeID, resultID};
   auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
+  if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
+    ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
+    if (innerShape.size() > 0)
+      elementType = spirv::TensorArmType::get(innerShape, elementType);
+  }
 
   // "If the Result Type is a cooperative matrix type, then there must be only
   // one Constituent, with scalar type matching the cooperative matrix Component
@@ -979,30 +984,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
     } else {
       return 0;
     }
-  } else if (isa<spirv::TensorArmType>(constType)) {
-    if (isZeroValue(valueAttr)) {
-      encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
-                            {typeID, resultID});
-      return resultID;
-    }
-    numberOfConstituents = shapedType.getNumElements();
-    operands.reserve(numberOfConstituents + 2);
-    for (int i = 0; i < numberOfConstituents; ++i) {
-      uint32_t elementID = 0;
-      if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
-        elementID =
-            elementType.isInteger(1)
-                ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
-                : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
-      }
-      if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
-        elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
-      }
-      if (!elementID) {
-        return 0;
-      }
-      operands.push_back(elementID);
-    }
+  } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                          {typeID, resultID});
+    return resultID;
   } else {
     operands.reserve(numberOfConstituents + 2);
     for (int i = 0; i < numberOfConstituents; ++i) {
diff --git a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
index 275e586f70634..7fb8af1904388 100644
--- a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
+++ b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
@@ -1,17 +1,36 @@
 // RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
-// DISABLED: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
-
-// FIXME(#152012): Fix arm tensor constant validation errors and reenable spirv-val tests.
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
 
 spirv.module Logical Vulkan requires #spirv.vce<v1.3,
              [VulkanMemoryModel, Shader, TensorsARM, Linkage], [SPV_KHR_vulkan_memory_model, SPV_ARM_tensors]> {
-  // CHECK-LABEL: @arm_tensor_of_i32
-  spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+  // CHECK-LABEL: @rank_1_arm_tensor_of_i32
+  spirv.func @rank_1_arm_tensor_of_i32() -> (!spirv.arm.tensor<3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+    %0 = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xi32>
+  }
+
+  // CHECK-LABEL: @rank_2_arm_tensor_of_i32
+  spirv.func @rank_2_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
     %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  // CHECK-LABEL: @rank_3_arm_tensor_of_i32
+  spirv.func @rank_3_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x2x3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1, 2, 3], [4, 5, 6]], {{\[}}[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+    %0 = spirv.Constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xi32>
+  }
+
+  // CHECK-LABEL: @rank_4_arm_tensor_of_i32
+  spirv.func @rank_4_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3x4x5xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+    %0 = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xi32>
+  }
+
   // CHECK-LABEL: @splat_arm_tensor_of_i32
   spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
@@ -19,13 +38,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3,
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
-  // CHECK-LABEL: @arm_tensor_of_f32
-  spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+  // CHECK-LABEL: @rank_1_arm_tensor_of_f32
+  spirv.func @rank_1_arm_tensor_of_f32() -> (!spirv.arm.tensor<3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : !spirv.arm.tensor<3xf32>
+    %0 = spirv.Constant dense<[1.0, 2.0, 3.0]> : !spirv.arm.tensor<3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+  }
+
+  // CHECK-LABEL: @rank_2_arm_tensor_of_f32
+  spirv.func @rank_2_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
-    %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
 
+  // CHECK-LABEL: @rank_3_arm_tensor_of_f32
+  spirv.func @rank_3_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x2x3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]], {{\[}}[7.000000e+00, 8.000000e+00, 9.000000e+00], [1.000000e+01, 1.100000e+01, 1.200000e+01]]]> : !spirv.arm.tensor<2x2x3xf32>
+    %0 = spirv.Constant dense<[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]> : !spirv.arm.tensor<2x2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xf32>
+  }
+
+  // CHECK-LABEL: @rank_4_arm_tensor_of_f32
+  spirv.func @rank_4_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3x4x5xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<5.000000e+00> : !spirv.arm.tensor<2x3x4x5xf32>
+    %0 = spirv.Constant dense<5.0> : !spirv.arm.tensor<2x3x4x5xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xf32>
+  }
+
   // CHECK-LABEL: @splat_arm_tensor_of_f32
   spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>

@kuhar kuhar changed the title [mlir][spirv]Fix serialization of TensorARM with rank higher than one [mlir][spirv] Fix serialization of TensorARM with rank higher than one Aug 7, 2025
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this. Could you rebase your PR? With #152124, we should be able to run spirv-val in the CI and make sure it doesn't complain.

@mahabadm
Copy link
Contributor Author

mahabadm commented Aug 7, 2025

@kuhar Thanks for your note. I have rebased and seems like that the test have passed.

@kuhar kuhar requested review from Hardcode84 and IgWod-IMG August 7, 2025 16:31
Copy link
Contributor

@IgWod-IMG IgWod-IMG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@mahabadm
Copy link
Contributor Author

mahabadm commented Aug 8, 2025

@kuhar Would you please kindly merge this patch, if there are no further comments? Many thanks.

@kuhar kuhar merged commit 688551f into llvm:main Aug 8, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] ARM Tensor constants fail to validate
4 participants