Skip to content

Commit 51ce7ac

Browse files
committed
Addressing code review comments
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent bee1e45 commit 51ce7ac

File tree

3 files changed

+135
-22
lines changed

3 files changed

+135
-22
lines changed

mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
7878
}
7979

8080
def SPIRVReplicatedConstantCompositePass
81-
: Pass<"spirv-replicated-const-composite", "spirv::ModuleOp"> {
81+
: Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> {
8282
let summary = "Convert splat composite constants and spec constants to"
8383
"corresponding replicated constant composite ops defined by"
8484
"SPV_EXT_replicated_composites";

mlir/lib/Dialect/SPIRV/Transforms/ConversionToReplicatedConstantCompositePass.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
//===- ConversionToReplicatedConstantCompositePass.cpp
2-
//---------------------------===//
1+
//===- ConversionToReplicatedConstantCompositePass.cpp --------------------===//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54
// See https://llvm.org/LICENSE.txt for license information.
@@ -83,7 +82,7 @@ class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
8382

8483
if (splatCount == 1)
8584
return rewriter.notifyMatchFailure(op,
86-
"composite has only one consituent");
85+
"composite has only one constituent");
8786

8887
rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
8988
op, op.getType(), splatAttr);
@@ -102,7 +101,7 @@ class SpecConstantCompositeOpConversion
102101
if (!compositeType)
103102
return rewriter.notifyMatchFailure(op, "not a composite constant");
104103

105-
auto constituents = op.getConstituents();
104+
ArrayAttr constituents = op.getConstituents();
106105
if (constituents.size() == 1)
107106
return rewriter.notifyMatchFailure(op,
108107
"composite has only one consituent");
@@ -113,6 +112,9 @@ class SpecConstantCompositeOpConversion
113112

114113
auto splatConstituent =
115114
dyn_cast<FlatSymbolRefAttr>(op.getConstituents()[0]);
115+
if (!splatConstituent)
116+
return rewriter.notifyMatchFailure(
117+
op, "expected flat symbol reference for splat constituent");
116118

117119
rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
118120
op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);

mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir

Lines changed: 128 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
// RUN: mlir-opt -spirv-replicated-const-composite -split-input-file -verify-diagnostics %s -o - | FileCheck %s
1+
// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s -o - | FileCheck %s
22

33
spirv.module Logical GLSL450 {
44
spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
5-
%0 = spirv.Constant dense<2> : vector<3xi32>
65
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
6+
%0 = spirv.Constant dense<2> : vector<3xi32>
77
spirv.ReturnValue %0 : vector<3xi32>
88
}
99
}
@@ -12,8 +12,8 @@ spirv.module Logical GLSL450 {
1212

1313
spirv.module Logical GLSL450 {
1414
spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
15-
%0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
1615
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
16+
%0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
1717
spirv.ReturnValue %0 : !spirv.array<3 x i32>
1818
}
1919
}
@@ -22,7 +22,7 @@ spirv.module Logical GLSL450 {
2222

2323
spirv.module Logical GLSL450 {
2424
spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
25-
// CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
25+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
2626
%0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
2727
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
2828
}
@@ -32,7 +32,7 @@ spirv.module Logical GLSL450 {
3232

3333
spirv.module Logical GLSL450 {
3434
spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
35-
// CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
35+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
3636
%0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
3737
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
3838
}
@@ -42,8 +42,8 @@ spirv.module Logical GLSL450 {
4242

4343
spirv.module Logical GLSL450 {
4444
spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
45-
%0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
4645
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
46+
%0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
4747
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
4848
}
4949
}
@@ -52,8 +52,8 @@ spirv.module Logical GLSL450 {
5252

5353
spirv.module Logical GLSL450 {
5454
spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
55-
%0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
5655
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
56+
%0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
5757
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
5858
}
5959
}
@@ -62,7 +62,7 @@ spirv.module Logical GLSL450 {
6262

6363
spirv.module Logical GLSL450 {
6464
spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
65-
// CHECK: spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
65+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
6666
%0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
6767
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
6868
}
@@ -72,7 +72,7 @@ spirv.module Logical GLSL450 {
7272

7373
spirv.module Logical GLSL450 {
7474
spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
75-
// CHECK: spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
75+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
7676
%0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
7777
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
7878
}
@@ -82,8 +82,8 @@ spirv.module Logical GLSL450 {
8282

8383
spirv.module Logical GLSL450 {
8484
spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
85-
%0 = spirv.Constant dense<2.0> : vector<3xf32>
8685
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
86+
%0 = spirv.Constant dense<2.0> : vector<3xf32>
8787
spirv.ReturnValue %0 : vector<3xf32>
8888
}
8989
}
@@ -92,8 +92,8 @@ spirv.module Logical GLSL450 {
9292

9393
spirv.module Logical GLSL450 {
9494
spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
95-
%0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
9695
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
96+
%0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
9797
spirv.ReturnValue %0 : !spirv.array<3 x f32>
9898
}
9999
}
@@ -102,7 +102,7 @@ spirv.module Logical GLSL450 {
102102

103103
spirv.module Logical GLSL450 {
104104
spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
105-
// CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
105+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
106106
%0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
107107
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
108108
}
@@ -112,7 +112,7 @@ spirv.module Logical GLSL450 {
112112

113113
spirv.module Logical GLSL450 {
114114
spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
115-
// CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
115+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
116116
%0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
117117
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
118118
}
@@ -122,8 +122,8 @@ spirv.module Logical GLSL450 {
122122

123123
spirv.module Logical GLSL450 {
124124
spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
125-
%0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
126125
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
126+
%0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
127127
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
128128
}
129129
}
@@ -132,8 +132,8 @@ spirv.module Logical GLSL450 {
132132

133133
spirv.module Logical GLSL450 {
134134
spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
135-
%0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
136135
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
136+
%0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
137137
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
138138
}
139139
}
@@ -142,7 +142,7 @@ spirv.module Logical GLSL450 {
142142

143143
spirv.module Logical GLSL450 {
144144
spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
145-
// CHECK: spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
145+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
146146
%0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
147147
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
148148
}
@@ -152,14 +152,94 @@ spirv.module Logical GLSL450 {
152152

153153
spirv.module Logical GLSL450 {
154154
spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
155-
// CHECK: spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
155+
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
156156
%0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
157157
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
158158
}
159159
}
160160

161161
// -----
162162

163+
spirv.module Logical GLSL450 {
164+
spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" {
165+
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
166+
%0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32>
167+
spirv.ReturnValue %0 : !spirv.array<1 x i32>
168+
}
169+
}
170+
171+
// -----
172+
173+
spirv.module Logical GLSL450 {
174+
spirv.func @arm_tensor_of_one_i32() -> (!spirv.arm.tensor<1xi32>) "None" {
175+
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
176+
%0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32>
177+
spirv.ReturnValue %0 : !spirv.arm.tensor<1xi32>
178+
}
179+
}
180+
181+
// -----
182+
183+
spirv.module Logical GLSL450 {
184+
spirv.func @non_splat_vector_of_i32() -> (vector<3xi32>) "None" {
185+
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
186+
%0 = spirv.Constant dense<[0, 1, 2]> : vector<3xi32>
187+
spirv.ReturnValue %0 : vector<3xi32>
188+
}
189+
}
190+
191+
// -----
192+
193+
spirv.module Logical GLSL450 {
194+
spirv.func @non_splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
195+
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
196+
%0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 3]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
197+
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
198+
}
199+
}
200+
201+
// -----
202+
203+
spirv.module Logical GLSL450 {
204+
spirv.func @array_of_one_f32() -> (!spirv.array<1 x f32>) "None" {
205+
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
206+
%0 = spirv.Constant [1.0 : f32] : !spirv.array<1 x f32>
207+
spirv.ReturnValue %0 : !spirv.array<1 x f32>
208+
}
209+
}
210+
211+
// -----
212+
213+
spirv.module Logical GLSL450 {
214+
spirv.func @arm_tensor_of_one_f32() -> (!spirv.arm.tensor<1xf32>) "None" {
215+
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
216+
%0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32>
217+
spirv.ReturnValue %0 : !spirv.arm.tensor<1xf32>
218+
}
219+
}
220+
221+
// -----
222+
223+
spirv.module Logical GLSL450 {
224+
spirv.func @non_splat_vector_of_f32() -> (vector<3xf32>) "None" {
225+
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
226+
%0 = spirv.Constant dense<[0.0, 1.0, 2.0]> : vector<3xf32>
227+
spirv.ReturnValue %0 : vector<3xf32>
228+
}
229+
}
230+
231+
// -----
232+
233+
spirv.module Logical GLSL450 {
234+
spirv.func @non_splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
235+
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
236+
%0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
237+
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
238+
}
239+
}
240+
241+
// -----
242+
163243
spirv.module Logical GLSL450 {
164244

165245
spirv.SpecConstant @sc_i32_1 = 1 : i32
@@ -189,4 +269,35 @@ spirv.module Logical GLSL450 {
189269

190270
// CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
191271
spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
272+
273+
spirv.SpecConstant @sc_i32_2 = 2 : i32
274+
275+
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
276+
spirv.SpecConstantComposite @scc_array_of_one_i32 (@sc_i32_1) : !spirv.array<1 x i32>
277+
278+
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
279+
spirv.SpecConstantComposite @scc_arm_tensor_of_one_i32 (@sc_i32_1) : !spirv.arm.tensor<1xi32>
280+
281+
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
282+
spirv.SpecConstantComposite @scc_non_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_2) : vector<3 x i32>
283+
284+
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
285+
spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_i32 (@sc_i32_2, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>
286+
287+
spirv.SpecConstant @sc_f32_2 = 2.0 : f32
288+
289+
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
290+
spirv.SpecConstantComposite @scc_array_of_one_f32 (@sc_f32_1) : !spirv.array<1 x f32>
291+
292+
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
293+
spirv.SpecConstantComposite @scc_arm_tensor_of_one_f32 (@sc_f32_1) : !spirv.arm.tensor<1xf32>
294+
295+
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
296+
spirv.SpecConstantComposite @scc_non_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_2) : vector<3 x f32>
297+
298+
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
299+
spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_f32 (@sc_f32_2, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>
300+
301+
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
302+
spirv.SpecConstantComposite @scc_struct_of_i32_and_f32 (@sc_i32_1, @sc_i32_1, @sc_f32_1) : !spirv.struct<(i32, i32, f32)>
192303
}

0 commit comments

Comments
 (0)