Skip to content

Commit 568bccb

Browse files
committed
Addressing further code review comments
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent 8234e68 commit 568bccb

File tree

3 files changed

+16
-32
lines changed

3 files changed

+16
-32
lines changed

mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
7878
}
7979

8080
def SPIRVReplicatedConstantCompositePass
81-
: Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> {
81+
: Pass<"spirv-promote-to-replicated-constants", "spirv::ModuleOp"> {
8282
let summary = "Convert splat composite constants and spec constants to "
8383
"corresponding replicated constant composite ops defined by "
8484
"SPV_EXT_replicated_composites";

mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,22 @@ namespace mlir::spirv {
2323
namespace {
2424

2525
static std::pair<Attribute, uint32_t>
26-
getSplatAttributeAndCount(Attribute valueAttr) {
26+
getSplatAttrAndNumElements(Attribute valueAttr) {
2727
Attribute attr;
2828
uint32_t splatCount = 0;
29-
if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
30-
if (denseAttr.isSplat()) {
31-
attr = denseAttr.getSplatValue<Attribute>();
32-
splatCount = denseAttr.size();
33-
}
34-
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
35-
if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(),
36-
std::not_equal_to<>()) == arrayAttr.end()) {
29+
if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
30+
return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
31+
}
32+
if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
33+
if (llvm::all_equal(arrayAttr)) {
3734
attr = arrayAttr[0];
3835
splatCount = arrayAttr.size();
3936
}
40-
}
4137

42-
if (attr) {
43-
auto typedAttr = dyn_cast<TypedAttr>(attr);
44-
if ((typedAttr && isa<spirv::CompositeType>(typedAttr.getType())) ||
45-
isa<ArrayAttr>(attr)) {
38+
if (attr) {
39+
// Find the inner-most splat value for array of composites
4640
std::pair<Attribute, uint32_t> newSplatAttrAndCount =
47-
getSplatAttributeAndCount(attr);
41+
getSplatAttrAndNumElements(attr);
4842
if (newSplatAttrAndCount.first) {
4943
return newSplatAttrAndCount;
5044
}
@@ -63,7 +57,7 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
6357
if (!compositeType)
6458
return rewriter.notifyMatchFailure(op, "not a composite constant");
6559

66-
auto [splatAttr, splatCount] = getSplatAttributeAndCount(op.getValue());
60+
auto [splatAttr, splatCount] = getSplatAttrAndNumElements(op.getValue());
6761
if (!splatAttr)
6862
return rewriter.notifyMatchFailure(op, "composite is not splat");
6963

@@ -73,7 +67,6 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
7367

7468
rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
7569
op, op.getType(), splatAttr);
76-
7770
return success();
7871
}
7972
};
@@ -93,8 +86,7 @@ struct SpecConstantCompositeOpConversion final
9386
return rewriter.notifyMatchFailure(op,
9487
"composite has only one consituent");
9588

96-
if (!(std::adjacent_find(constituents.begin(), constituents.end(),
97-
std::not_equal_to<>()) == constituents.end()))
89+
if (!(llvm::all_equal(constituents)))
9890
return rewriter.notifyMatchFailure(op, "composite is not splat");
9991

10092
auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);

mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s | FileCheck %s
1+
// RUN: mlir-opt --spirv-promote-to-replicated-constants --split-input-file %s | FileCheck %s
22

3-
spirv.module Logical GLSL450 {
3+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
44
spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
55
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
66
%0 = spirv.Constant dense<2> : vector<3xi32>
@@ -132,21 +132,13 @@ spirv.module Logical GLSL450 {
132132
%0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32>
133133
spirv.ReturnValue %0 : !spirv.arm.tensor<1xf32>
134134
}
135-
}
136-
137-
// -----
138135

139-
spirv.module Logical GLSL450 {
140136
spirv.func @non_splat_vector_of_f32() -> (vector<3xf32>) "None" {
141137
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
142138
%0 = spirv.Constant dense<[0.0, 1.0, 2.0]> : vector<3xf32>
143139
spirv.ReturnValue %0 : vector<3xf32>
144140
}
145-
}
146141

147-
// -----
148-
149-
spirv.module Logical GLSL450 {
150142
spirv.func @non_splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
151143
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
152144
%0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
@@ -156,7 +148,7 @@ spirv.module Logical GLSL450 {
156148

157149
// -----
158150

159-
spirv.module Logical GLSL450 {
151+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompositesEXT], [SPV_EXT_replicated_composites]> {
160152

161153
spirv.SpecConstant @sc_i32_1 = 1 : i32
162154

@@ -216,4 +208,4 @@ spirv.module Logical GLSL450 {
216208

217209
// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
218210
spirv.SpecConstantComposite @scc_struct_of_i32_and_f32 (@sc_i32_1, @sc_i32_1, @sc_f32_1) : !spirv.struct<(i32, i32, f32)>
219-
}
211+
}

0 commit comments

Comments
 (0)