Skip to content

Commit 6ea73e3

Browse files
committed
Slight change of logic for value type detection
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent ae1eb18 commit 6ea73e3

File tree

2 files changed

+96
-8
lines changed

2 files changed

+96
-8
lines changed

mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,39 @@ namespace mlir::spirv {
2222

2323
namespace {
2424

25+
static Type getArrayElemType(Attribute attr) {
26+
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
27+
return typedAttr.getType();
28+
}
29+
30+
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
31+
return ArrayType::get(getArrayElemType(arrayAttr[0]), arrayAttr.size());
32+
}
33+
34+
return nullptr;
35+
}
36+
2537
static std::pair<Attribute, uint32_t>
26-
getSplatAttrAndNumElements(Attribute valueAttr) {
38+
getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) {
2739
Attribute attr;
28-
uint32_t numElements = 0;
40+
uint32_t numElements = 1;
41+
42+
auto compositeType = dyn_cast_or_null<spirv::CompositeType>(valueType);
43+
if (!compositeType)
44+
return {nullptr, 1};
45+
2946
if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
3047
return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
3148
}
49+
3250
if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
3351
if (llvm::all_equal(arrayAttr)) {
3452
attr = arrayAttr[0];
3553
numElements = arrayAttr.size();
3654

3755
// Find the inner-most splat value for array of composites
38-
auto [newAttr, newNumElements] = getSplatAttrAndNumElements(attr);
56+
auto [newAttr, newNumElements] =
57+
getSplatAttrAndNumElements(attr, getArrayElemType(attr));
3958
if (newAttr) {
4059
return {newAttr, numElements * newNumElements};
4160
}
@@ -50,11 +69,8 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
5069

5170
LogicalResult matchAndRewrite(spirv::ConstantOp op,
5271
PatternRewriter &rewriter) const override {
53-
auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
54-
if (!compositeType)
55-
return rewriter.notifyMatchFailure(op, "not a composite constant");
56-
57-
auto [attr, numElements] = getSplatAttrAndNumElements(op.getValue());
72+
auto [attr, numElements] =
73+
getSplatAttrAndNumElements(op.getValue(), op.getType());
5874
if (!attr)
5975
return rewriter.notifyMatchFailure(op, "composite is not splat");
6076

mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,36 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
4949
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
5050
}
5151

52+
spirv.func @array_of_splat_array_of_non_splat_vectors_of_i32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xi32>>>) "None" {
53+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>
54+
%0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
55+
spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xi32>>>
56+
}
57+
58+
spirv.func @array_of_one_splat_array_of_vector_of_one_i32() -> !spirv.array<1 x !spirv.array<2 x vector<1xi32>>> "None" {
59+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>
60+
%cst = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
61+
spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xi32>>>
62+
}
63+
64+
spirv.func @splat_array_of_array_of_one_vector_of_one_i32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xi32>>>) "None" {
65+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1> : vector<1xi32>] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
66+
%0 = spirv.Constant [[dense<1> : vector<1xi32>], [dense<1> : vector<1xi32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
67+
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xi32>>>
68+
}
69+
70+
spirv.func @array_of_one_array_of_one_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" {
71+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
72+
%0 = spirv.Constant [[dense<1> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
73+
spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
74+
}
75+
76+
spirv.func @splat_array_of_splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>) "None" {
77+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
78+
%0 = spirv.Constant [[[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]], [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
79+
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>>
80+
}
81+
5282
spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
5383
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
5484
%0 = spirv.Constant dense<2.0> : vector<3xf32>
@@ -97,6 +127,36 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
97127
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
98128
}
99129

130+
spirv.func @array_of_splat_array_of_non_splat_vectors_of_f32() -> (!spirv.array<1 x !spirv.array<2 x vector<2xf32>>>) "None" {
131+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>
132+
%0 = spirv.Constant [[dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
133+
spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<2 x vector<2xf32>>>
134+
}
135+
136+
spirv.func @array_of_one_splat_array_of_vector_of_one_f32() -> !spirv.array<1 x !spirv.array<2 x vector<1xf32>>> "None" {
137+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>
138+
%cst = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
139+
spirv.ReturnValue %cst : !spirv.array<1 x !spirv.array<2 x vector<1xf32>>>
140+
}
141+
142+
spirv.func @splat_array_of_array_of_one_vector_of_one_f32() -> (!spirv.array<2 x !spirv.array<1 x vector<1xf32>>>) "None" {
143+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<1.000000e+00> : vector<1xf32>] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
144+
%0 = spirv.Constant [[dense<1.0> : vector<1xf32>], [dense<1.0> : vector<1xf32>]] : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
145+
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<1 x vector<1xf32>>>
146+
}
147+
148+
spirv.func @array_of_one_array_of_one_splat_vector_of_f32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xf32>>>) "None" {
149+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
150+
%0 = spirv.Constant [[dense<1.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
151+
spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xf32>>>
152+
}
153+
154+
spirv.func @splat_array_of_splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>) "None" {
155+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
156+
%0 = spirv.Constant [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]], [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
157+
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>>
158+
}
159+
100160
spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" {
101161
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
102162
%0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32>
@@ -144,6 +204,18 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
144204
%0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
145205
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
146206
}
207+
208+
spirv.func @array_of_one_array_of_one_non_splat_vector_of_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<2xi32>>>) "None" {
209+
// CHECK-NOT spirv.EXT.ConstantCompositeReplicate
210+
%0 = spirv.Constant [[dense<[1, 2]> : vector<2xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
211+
spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<2xi32>>>
212+
}
213+
214+
spirv.func @array_of_one_array_of_one_vector_of_one_i32() -> (!spirv.array<1 x !spirv.array<1 x vector<1xi32>>>) "None" {
215+
// CHECK-NOT spirv.EXT.ConstantCompositeReplicate
216+
%0 = spirv.Constant [[dense<1> : vector<1xi32>]] : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
217+
spirv.ReturnValue %0 : !spirv.array<1 x !spirv.array<1 x vector<1xi32>>>
218+
}
147219
}
148220

149221
// -----

0 commit comments

Comments
 (0)