From c246a1a4ab86aec86585356aede89b31ac4309ea Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Fri, 18 Jul 2025 14:04:15 +0100 Subject: [PATCH 1/3] [mlir][spirv] Fix verification and serialization replicated constant composites of multi-dimensional array This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array. Signed-off-by: Mohammadreza Ameri Mahabadian --- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 25 +++++++++++-------- .../Target/SPIRV/Serialization/Serializer.cpp | 25 +++++++++++-------- mlir/test/Target/SPIRV/constant.mlir | 14 +++++++++++ 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 52c672a05fa43..c8b87fad8ccad 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -767,19 +767,22 @@ void mlir::spirv::AddressOfOp::getAsmResultNames( // spirv.EXTConstantCompositeReplicate //===----------------------------------------------------------------------===// +static Type getValueType(Attribute attr) { + if (auto typedAttr = dyn_cast(attr)) { + return typedAttr.getType(); + } + + if (auto arrayAttr = dyn_cast(attr)) { + return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size()); + } + + return nullptr; +} + LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() { - Type valueType; - if (auto typedAttr = dyn_cast(getValue())) { - valueType = typedAttr.getType(); - } else if (auto arrayAttr = dyn_cast(getValue())) { - auto typedElemAttr = dyn_cast(arrayAttr[0]); - if (!typedElemAttr) - return emitError("value attribute is not typed"); - valueType = - spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size()); - } else { + Type valueType = getValueType(getValue()); + if (!valueType) return emitError("unknown value attribute type"); - } auto compositeType = dyn_cast(getType()); if (!compositeType) diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index a8a2b2e7cf38c..9e81b6ca505a2 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -1124,6 +1124,18 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, return resultID; } +static Type getValueType(Attribute attr) { + if (auto typedAttr = dyn_cast(attr)) { + return typedAttr.getType(); + } + + if (auto arrayAttr = dyn_cast(attr)) { + return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size()); + } + + return nullptr; +} + uint32_t Serializer::prepareConstantCompositeReplicate(Location loc, Type resultType, Attribute valueAttr) { @@ -1137,18 +1149,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc, return 0; } - Type valueType; - if (auto typedAttr = dyn_cast(valueAttr)) { - valueType = typedAttr.getType(); - } else if (auto arrayAttr = dyn_cast(valueAttr)) { - auto typedElemAttr = dyn_cast(arrayAttr[0]); - if (!typedElemAttr) - return 0; - valueType = - spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size()); - } else { + Type valueType = getValueType(valueAttr); + if (!valueAttr) return 0; - } auto compositeType = dyn_cast(resultType); if (!compositeType) diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index 6aca11ec5e6e6..e06c0146d4ad2 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -363,6 +363,13 @@ spirv.module Logical GLSL450 requires #spirv.vce } + // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_i32 + spirv.func @array_of_splat_array_of_non_splat_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + %0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + } + // CHECK-LABEL: @splat_vector_f32 spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" { // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32> @@ -411,4 +418,11 @@ spirv.module Logical GLSL450 requires #spirv.vce spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> } + + // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_f32 + spirv.func @array_of_splat_array_of_non_splat_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + %0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + } } From f0ef5237a1ccb5718211c6c78223458262e4c84b Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Wed, 30 Jul 2025 14:14:05 +0100 Subject: [PATCH 2/3] Addressing code review comments Signed-off-by: Mohammadreza Ameri Mahabadian --- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 3 +++ mlir/test/Target/SPIRV/constant.mlir | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index c8b87fad8ccad..1306b693de498 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -767,6 +767,9 @@ void mlir::spirv::AddressOfOp::getAsmResultNames( // spirv.EXTConstantCompositeReplicate //===----------------------------------------------------------------------===// +// Returns type of attribute. In case of a TypedAttr this will simply return +// the type. But for an ArrayAttr which is untyped and can be multidimensional +// it creates the ArrayType recursively. static Type getValueType(Attribute attr) { if (auto typedAttr = dyn_cast(attr)) { return typedAttr.getType(); diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index e06c0146d4ad2..e68267373d0af 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -363,8 +363,8 @@ spirv.module Logical GLSL450 requires #spirv.vce } - // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_i32 - spirv.func @array_of_splat_array_of_non_splat_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" { + // CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_i32 + spirv.func @splat_array_of_non_splat_array_of_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" { // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> %0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> @@ -419,8 +419,8 @@ spirv.module Logical GLSL450 requires #spirv.vce } - // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_f32 - spirv.func @array_of_splat_array_of_non_splat_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" { + // CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_f32 + spirv.func @splat_array_of_non_splat_array_of_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" { // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> %0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> From 0e80019fb171b367257c9ee1f3241e161d422329 Mon Sep 17 00:00:00 2001 From: Mohammadreza Ameri Mahabadian Date: Fri, 1 Aug 2025 17:43:46 +0100 Subject: [PATCH 3/3] Addressing further code review comments Signed-off-by: Mohammadreza Ameri Mahabadian --- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 2 +- mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 1306b693de498..f99339852824c 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -767,7 +767,7 @@ void mlir::spirv::AddressOfOp::getAsmResultNames( // spirv.EXTConstantCompositeReplicate //===----------------------------------------------------------------------===// -// Returns type of attribute. In case of a TypedAttr this will simply return +// Returns type of attribute. In case of a TypedAttr this will simply return // the type. But for an ArrayAttr which is untyped and can be multidimensional // it creates the ArrayType recursively. static Type getValueType(Attribute attr) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 1cb3d71d0d4df..30536638b56f7 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -1187,6 +1187,9 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, return resultID; } +// Returns type of attribute. In case of a TypedAttr this will simply return +// the type. But for an ArrayAttr which is untyped and can be multidimensional +// it creates the ArrayType recursively. static Type getValueType(Attribute attr) { if (auto typedAttr = dyn_cast(attr)) { return typedAttr.getType();