Skip to content

Commit c246a1a

Browse files
committed
[mlir][spirv] Fix verification and serialization replicated constant composites of multi-dimensional array
This fixes a bug in verification and serialization of replicated constant composite ops where the splat value can potentially be a multi-dimensional array. Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent 74001be commit c246a1a

File tree

3 files changed

+42
-22
lines changed

3 files changed

+42
-22
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -767,19 +767,22 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
767767
// spirv.EXTConstantCompositeReplicate
768768
//===----------------------------------------------------------------------===//
769769

770+
static Type getValueType(Attribute attr) {
771+
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
772+
return typedAttr.getType();
773+
}
774+
775+
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
776+
return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
777+
}
778+
779+
return nullptr;
780+
}
781+
770782
LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
771-
Type valueType;
772-
if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
773-
valueType = typedAttr.getType();
774-
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
775-
auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
776-
if (!typedElemAttr)
777-
return emitError("value attribute is not typed");
778-
valueType =
779-
spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
780-
} else {
783+
Type valueType = getValueType(getValue());
784+
if (!valueType)
781785
return emitError("unknown value attribute type");
782-
}
783786

784787
auto compositeType = dyn_cast<spirv::CompositeType>(getType());
785788
if (!compositeType)

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,18 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
11241124
return resultID;
11251125
}
11261126

1127+
static Type getValueType(Attribute attr) {
1128+
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1129+
return typedAttr.getType();
1130+
}
1131+
1132+
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1133+
return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
1134+
}
1135+
1136+
return nullptr;
1137+
}
1138+
11271139
uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
11281140
Type resultType,
11291141
Attribute valueAttr) {
@@ -1137,18 +1149,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
11371149
return 0;
11381150
}
11391151

1140-
Type valueType;
1141-
if (auto typedAttr = dyn_cast<TypedAttr>(valueAttr)) {
1142-
valueType = typedAttr.getType();
1143-
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1144-
auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
1145-
if (!typedElemAttr)
1146-
return 0;
1147-
valueType =
1148-
spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
1149-
} else {
1152+
Type valueType = getValueType(valueAttr);
1153+
if (!valueAttr)
11501154
return 0;
1151-
}
11521155

11531156
auto compositeType = dyn_cast<CompositeType>(resultType);
11541157
if (!compositeType)

mlir/test/Target/SPIRV/constant.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
363363
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
364364
}
365365

366+
// CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_i32
367+
spirv.func @array_of_splat_array_of_non_splat_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" {
368+
// CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
369+
%0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
370+
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
371+
}
372+
366373
// CHECK-LABEL: @splat_vector_f32
367374
spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
368375
// CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -411,4 +418,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
411418
%0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
412419
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
413420
}
421+
422+
// CHECK-LABEL: @array_of_splat_array_of_non_splat_arrays_of_f32
423+
spirv.func @array_of_splat_array_of_non_splat_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" {
424+
// CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
425+
%0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
426+
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
427+
}
414428
}

0 commit comments

Comments
 (0)