-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][spirv] Fix verification and serialization replicated constant … #151168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Fix verification and serialization replicated constant … #151168
Conversation
…composites of multi-dimensional array This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array. Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
|
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Mohammadreza Ameri Mahabadian (mahabadm) Changes…composites of multi-dimensional array This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array. Full diff: https://github.com/llvm/llvm-project/pull/151168.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 52c672a05fa43..c8b87fad8ccad 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -767,19 +767,22 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
// spirv.EXTConstantCompositeReplicate
//===----------------------------------------------------------------------===//
+static Type getValueType(Attribute attr) {
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+ return typedAttr.getType();
+ }
+
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
+ }
+
+ return nullptr;
+}
+
LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
- Type valueType;
- if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
- valueType = typedAttr.getType();
- } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
- auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
- if (!typedElemAttr)
- return emitError("value attribute is not typed");
- valueType =
- spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
- } else {
+ Type valueType = getValueType(getValue());
+ if (!valueType)
return emitError("unknown value attribute type");
- }
auto compositeType = dyn_cast<spirv::CompositeType>(getType());
if (!compositeType)
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index a8a2b2e7cf38c..9e81b6ca505a2 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1124,6 +1124,18 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
return resultID;
}
+static Type getValueType(Attribute attr) {
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+ return typedAttr.getType();
+ }
+
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
+ }
+
+ return nullptr;
+}
+
uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
Type resultType,
Attribute valueAttr) {
@@ -1137,18 +1149,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
return 0;
}
- Type valueType;
- if (auto typedAttr = dyn_cast<TypedAttr>(valueAttr)) {
- valueType = typedAttr.getType();
- } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
- auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
- if (!typedElemAttr)
- return 0;
- valueType =
- spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
- } else {
+ Type valueType = getValueType(valueAttr);
+ if (!valueAttr)
return 0;
- }
auto compositeType = dyn_cast<CompositeType>(resultType);
if (!compositeType)
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 6aca11ec5e6e6..e06c0146d4ad2 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -363,6 +363,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}
+ // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_i32
+ spirv.func @array_of_splat_array_of_non_splat_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
+ }
+
// CHECK-LABEL: @splat_vector_f32
spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
// CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -411,4 +418,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
%0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}
+
+ // CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_f32
+ spirv.func @array_of_splat_array_of_non_splat_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" {
+ // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+ %0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+ spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
+ }
}
|
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
|
@kuhar Can you please advise if there are further changes needed for this? Appreciate that. |
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
…composites of multi-dimensional array
This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array.