Skip to content

Commit d065a5a

Browse files
committed
Addressing further code review comments
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent 51ce7ac commit d065a5a

File tree

4 files changed

+46
-134
lines changed

4 files changed

+46
-134
lines changed

mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
7979

8080
def SPIRVReplicatedConstantCompositePass
8181
: Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> {
82-
let summary = "Convert splat composite constants and spec constants to"
83-
"corresponding replicated constant composite ops defined by"
82+
let summary = "Convert splat composite constants and spec constants to "
83+
"corresponding replicated constant composite ops defined by "
8484
"SPV_EXT_replicated_composites";
8585
}
8686

mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
set(LLVM_OPTIONAL_SOURCES
22
CanonicalizeGLPass.cpp
3-
ConversionToReplicatedConstantCompositePass.cpp
3+
ConvertToReplicatedConstantCompositePass.cpp
44
DecorateCompositeTypeLayoutPass.cpp
55
LowerABIAttributesPass.cpp
66
RewriteInsertsPass.cpp
@@ -31,7 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion
3131

3232
add_mlir_dialect_library(MLIRSPIRVTransforms
3333
CanonicalizeGLPass.cpp
34-
ConversionToReplicatedConstantCompositePass.cpp
34+
ConvertToReplicatedConstantCompositePass.cpp
3535
DecorateCompositeTypeLayoutPass.cpp
3636
LowerABIAttributesPass.cpp
3737
RewriteInsertsPass.cpp
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- ConversionToReplicatedConstantCompositePass.cpp --------------------===//
1+
//===- ConvertToReplicatedConstantCompositePass.cpp --------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -14,21 +14,18 @@
1414

1515
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1616
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
17-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
1818

19-
namespace mlir {
20-
namespace spirv {
19+
namespace mlir::spirv {
2120
#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
2221
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
23-
} // namespace spirv
24-
} // namespace mlir
25-
26-
using namespace mlir;
2722

2823
namespace {
2924

30-
Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) {
25+
static std::pair<Attribute, uint32_t>
26+
getSplatAttributeAndCount(Attribute valueAttr) {
3127
Attribute attr;
28+
uint32_t splatCount = 0;
3229
if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
3330
if (denseAttr.isSplat()) {
3431
attr = denseAttr.getSplatValue<Attribute>();
@@ -44,56 +41,53 @@ Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) {
4441

4542
if (attr) {
4643
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
47-
if (isa<spirv::CompositeType>(typedAttr.getType()))
48-
if (Attribute newAttr = getSplatAttribute(attr, splatCount))
49-
attr = newAttr;
44+
if (isa<spirv::CompositeType>(typedAttr.getType())) {
45+
std::pair<Attribute, uint32_t> newSplatAttrAndCount =
46+
getSplatAttributeAndCount(attr);
47+
if (newSplatAttrAndCount.first) {
48+
return newSplatAttrAndCount;
49+
}
50+
}
5051
} else if (isa<ArrayAttr>(attr)) {
51-
if (Attribute newAttr = getSplatAttribute(attr, splatCount))
52-
attr = newAttr;
52+
std::pair<Attribute, uint32_t> newSplatAttrAndCount =
53+
getSplatAttributeAndCount(attr);
54+
if (newSplatAttrAndCount.first) {
55+
return newSplatAttrAndCount;
56+
}
5357
}
5458
}
5559

56-
return attr;
60+
return {attr, splatCount};
5761
}
5862

59-
} // namespace
60-
61-
namespace {
62-
class ConversionToReplicatedConstantCompositePass
63-
: public spirv::impl::SPIRVReplicatedConstantCompositePassBase<
64-
ConversionToReplicatedConstantCompositePass> {
65-
public:
66-
void runOnOperation() override;
67-
};
68-
69-
class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
70-
using OpRewritePattern<spirv::ConstantOp>::OpRewritePattern;
63+
struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
64+
using OpRewritePattern::OpRewritePattern;
7165

7266
LogicalResult matchAndRewrite(spirv::ConstantOp op,
7367
PatternRewriter &rewriter) const override {
7468
auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
7569
if (!compositeType)
7670
return rewriter.notifyMatchFailure(op, "not a composite constant");
7771

78-
uint32_t splatCount = 0;
79-
Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount);
80-
if (!splatAttr)
72+
std::pair<Attribute, uint32_t> splatAttrAndCount =
73+
getSplatAttributeAndCount(op.getValue());
74+
if (!splatAttrAndCount.first)
8175
return rewriter.notifyMatchFailure(op, "composite is not splat");
8276

83-
if (splatCount == 1)
77+
if (splatAttrAndCount.second == 1)
8478
return rewriter.notifyMatchFailure(op,
8579
"composite has only one constituent");
8680

8781
rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
88-
op, op.getType(), splatAttr);
82+
op, op.getType(), splatAttrAndCount.first);
8983

9084
return success();
9185
}
9286
};
9387

94-
class SpecConstantCompositeOpConversion
95-
: public OpRewritePattern<spirv::SpecConstantCompositeOp> {
96-
using OpRewritePattern<spirv::SpecConstantCompositeOp>::OpRewritePattern;
88+
struct SpecConstantCompositeOpConversion final
89+
: OpRewritePattern<spirv::SpecConstantCompositeOp> {
90+
using OpRewritePattern::OpRewritePattern;
9791

9892
LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
9993
PatternRewriter &rewriter) const override {
@@ -123,15 +117,17 @@ class SpecConstantCompositeOpConversion
123117
}
124118
};
125119

126-
void ConversionToReplicatedConstantCompositePass::runOnOperation() {
127-
MLIRContext *context = &getContext();
128-
RewritePatternSet patterns(context);
129-
patterns.add<ConstantOpConversion>(context);
130-
patterns.add<SpecConstantCompositeOpConversion>(context);
131-
132-
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
133-
signalPassFailure();
120+
struct ConvertToReplicatedConstantCompositePass final
121+
: spirv::impl::SPIRVReplicatedConstantCompositePassBase<
122+
ConvertToReplicatedConstantCompositePass> {
123+
void runOnOperation() override {
124+
MLIRContext *context = &getContext();
125+
RewritePatternSet patterns(context);
126+
patterns.add<ConstantOpConversion, SpecConstantCompositeOpConversion>(
127+
context);
128+
walkAndApplyPatterns(getOperation(), std::move(patterns));
134129
}
135-
}
130+
};
136131

137-
} // namespace
132+
} // namespace
133+
} // namespace mlir::spirv

mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir

Lines changed: 1 addition & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,216 +1,132 @@
1-
// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s -o - | FileCheck %s
1+
// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s | FileCheck %s
22

33
spirv.module Logical GLSL450 {
44
spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
55
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
66
%0 = spirv.Constant dense<2> : vector<3xi32>
77
spirv.ReturnValue %0 : vector<3xi32>
88
}
9-
}
10-
11-
// -----
129

13-
spirv.module Logical GLSL450 {
1410
spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
1511
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
1612
%0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
1713
spirv.ReturnValue %0 : !spirv.array<3 x i32>
1814
}
19-
}
2015

21-
// -----
22-
23-
spirv.module Logical GLSL450 {
2416
spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
2517
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
2618
%0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
2719
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
2820
}
29-
}
30-
31-
// -----
3221

33-
spirv.module Logical GLSL450 {
3422
spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
3523
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
3624
%0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
3725
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
3826
}
39-
}
4027

41-
// -----
42-
43-
spirv.module Logical GLSL450 {
4428
spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
4529
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
4630
%0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
4731
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
4832
}
49-
}
50-
51-
// -----
5233

53-
spirv.module Logical GLSL450 {
5434
spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
5535
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
5636
%0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
5737
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
5838
}
59-
}
6039

61-
// -----
62-
63-
spirv.module Logical GLSL450 {
6440
spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
6541
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
6642
%0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
6743
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
6844
}
69-
}
70-
71-
// -----
7245

73-
spirv.module Logical GLSL450 {
7446
spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
7547
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
7648
%0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
7749
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
7850
}
79-
}
8051

81-
// -----
82-
83-
spirv.module Logical GLSL450 {
8452
spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
8553
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
8654
%0 = spirv.Constant dense<2.0> : vector<3xf32>
8755
spirv.ReturnValue %0 : vector<3xf32>
8856
}
89-
}
90-
91-
// -----
9257

93-
spirv.module Logical GLSL450 {
9458
spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
9559
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
9660
%0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
9761
spirv.ReturnValue %0 : !spirv.array<3 x f32>
9862
}
99-
}
10063

101-
// -----
102-
103-
spirv.module Logical GLSL450 {
10464
spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
10565
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
10666
%0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
10767
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
10868
}
109-
}
110-
111-
// -----
11269

113-
spirv.module Logical GLSL450 {
11470
spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
11571
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
11672
%0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
11773
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
11874
}
119-
}
12075

121-
// -----
122-
123-
spirv.module Logical GLSL450 {
12476
spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
12577
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
12678
%0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
12779
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
12880
}
129-
}
130-
131-
// -----
13281

133-
spirv.module Logical GLSL450 {
13482
spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
13583
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
13684
%0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
13785
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
13886
}
139-
}
14087

141-
// -----
142-
143-
spirv.module Logical GLSL450 {
14488
spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
14589
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
14690
%0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
14791
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
14892
}
149-
}
150-
151-
// -----
15293

153-
spirv.module Logical GLSL450 {
15494
spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
15595
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
15696
%0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
15797
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
15898
}
159-
}
160-
161-
// -----
16299

163-
spirv.module Logical GLSL450 {
164100
spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" {
165101
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
166102
%0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32>
167103
spirv.ReturnValue %0 : !spirv.array<1 x i32>
168104
}
169-
}
170-
171-
// -----
172105

173-
spirv.module Logical GLSL450 {
174106
spirv.func @arm_tensor_of_one_i32() -> (!spirv.arm.tensor<1xi32>) "None" {
175107
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
176108
%0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32>
177109
spirv.ReturnValue %0 : !spirv.arm.tensor<1xi32>
178110
}
179-
}
180111

181-
// -----
182-
183-
spirv.module Logical GLSL450 {
184112
spirv.func @non_splat_vector_of_i32() -> (vector<3xi32>) "None" {
185113
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
186114
%0 = spirv.Constant dense<[0, 1, 2]> : vector<3xi32>
187115
spirv.ReturnValue %0 : vector<3xi32>
188116
}
189-
}
190117

191-
// -----
192-
193-
spirv.module Logical GLSL450 {
194118
spirv.func @non_splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
195119
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
196120
%0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 3]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
197121
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
198122
}
199-
}
200-
201-
// -----
202123

203-
spirv.module Logical GLSL450 {
204124
spirv.func @array_of_one_f32() -> (!spirv.array<1 x f32>) "None" {
205125
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
206126
%0 = spirv.Constant [1.0 : f32] : !spirv.array<1 x f32>
207127
spirv.ReturnValue %0 : !spirv.array<1 x f32>
208128
}
209-
}
210129

211-
// -----
212-
213-
spirv.module Logical GLSL450 {
214130
spirv.func @arm_tensor_of_one_f32() -> (!spirv.arm.tensor<1xf32>) "None" {
215131
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
216132
%0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32>

0 commit comments

Comments
 (0)