Skip to content

Conversation

@mahabadm
Copy link
Contributor

…composites of multi-dimensional array

This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array.

…composites of multi-dimensional array

This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array.

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Jul 29, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Mohammadreza Ameri Mahabadian (mahabadm)

Changes

…composites of multi-dimensional array

This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array.


Full diff: https://github.com/llvm/llvm-project/pull/151168.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+14-11)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+14-11)
  • (modified) mlir/test/Target/SPIRV/constant.mlir (+14)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 52c672a05fa43..c8b87fad8ccad 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -767,19 +767,22 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
 // spirv.EXTConstantCompositeReplicate
 //===----------------------------------------------------------------------===//
 
+static Type getValueType(Attribute attr) {
+  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+    return typedAttr.getType();
+  }
+
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
+  }
+
+  return nullptr;
+}
+
 LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
-  Type valueType;
-  if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
-    valueType = typedAttr.getType();
-  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
-    auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
-    if (!typedElemAttr)
-      return emitError("value attribute is not typed");
-    valueType =
-        spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
-  } else {
+  Type valueType = getValueType(getValue());
+  if (!valueType)
     return emitError("unknown value attribute type");
-  }
 
   auto compositeType = dyn_cast<spirv::CompositeType>(getType());
   if (!compositeType)
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index a8a2b2e7cf38c..9e81b6ca505a2 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1124,6 +1124,18 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
   return resultID;
 }
 
+static Type getValueType(Attribute attr) {
+  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+    return typedAttr.getType();
+  }
+
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
+  }
+
+  return nullptr;
+}
+
 uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
                                                        Type resultType,
                                                        Attribute valueAttr) {
@@ -1137,18 +1149,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
     return 0;
   }
 
-  Type valueType;
-  if (auto typedAttr = dyn_cast<TypedAttr>(valueAttr)) {
-    valueType = typedAttr.getType();
-  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
-    auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
-    if (!typedElemAttr)
-      return 0;
-    valueType =
-        spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
-  } else {
+  Type valueType = getValueType(valueAttr);
+  if (!valueAttr)
     return 0;
-  }
 
   auto compositeType = dyn_cast<CompositeType>(resultType);
   if (!compositeType)
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 6aca11ec5e6e6..e06c0146d4ad2 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -363,6 +363,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_i32
+  spirv.func @array_of_splat_array_of_non_splat_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+  }
+
   // CHECK-LABEL: @splat_vector_f32
   spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
     // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -411,4 +418,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
     %0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
+
+  // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_f32
+  spirv.func @array_of_splat_array_of_non_splat_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" {
+    // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+    %0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+    spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+  }
 }

@mahabadm
Copy link
Contributor Author

mahabadm commented Aug 1, 2025

@kuhar Can you please advise if there are further changes needed for this? Appreciate that.

Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
@kuhar kuhar merged commit 6c072c0 into llvm:main Aug 1, 2025
9 checks passed
@mahabadm mahabadm deleted the multidim_replicated_const_composites_fix branch August 3, 2025 14:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants