Skip to content

Commit b718e34

Browse files
committed
Addressing further code review comments
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent d065a5a commit b718e34

File tree

1 file changed

+9
-17
lines changed

1 file changed

+9
-17
lines changed

mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- ConvertToReplicatedConstantCompositePass.cpp --------------------===//
1+
//===- ConvertToReplicatedConstantCompositePass.cpp -----------------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -40,15 +40,9 @@ getSplatAttributeAndCount(Attribute valueAttr) {
4040
}
4141

4242
if (attr) {
43-
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
44-
if (isa<spirv::CompositeType>(typedAttr.getType())) {
45-
std::pair<Attribute, uint32_t> newSplatAttrAndCount =
46-
getSplatAttributeAndCount(attr);
47-
if (newSplatAttrAndCount.first) {
48-
return newSplatAttrAndCount;
49-
}
50-
}
51-
} else if (isa<ArrayAttr>(attr)) {
43+
auto typedAttr = dyn_cast<TypedAttr>(attr);
44+
if ((typedAttr && isa<spirv::CompositeType>(typedAttr.getType())) ||
45+
isa<ArrayAttr>(attr)) {
5246
std::pair<Attribute, uint32_t> newSplatAttrAndCount =
5347
getSplatAttributeAndCount(attr);
5448
if (newSplatAttrAndCount.first) {
@@ -69,17 +63,16 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
6963
if (!compositeType)
7064
return rewriter.notifyMatchFailure(op, "not a composite constant");
7165

72-
std::pair<Attribute, uint32_t> splatAttrAndCount =
73-
getSplatAttributeAndCount(op.getValue());
74-
if (!splatAttrAndCount.first)
66+
auto [splattAttr, splatCount] = getSplatAttributeAndCount(op.getValue());
67+
if (!splattAttr)
7568
return rewriter.notifyMatchFailure(op, "composite is not splat");
7669

77-
if (splatAttrAndCount.second == 1)
70+
if (splatCount == 1)
7871
return rewriter.notifyMatchFailure(op,
7972
"composite has only one constituent");
8073

8174
rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
82-
op, op.getType(), splatAttrAndCount.first);
75+
op, op.getType(), splattAttr);
8376

8477
return success();
8578
}
@@ -104,8 +97,7 @@ struct SpecConstantCompositeOpConversion final
10497
std::not_equal_to<>()) == constituents.end()))
10598
return rewriter.notifyMatchFailure(op, "composite is not splat");
10699

107-
auto splatConstituent =
108-
dyn_cast<FlatSymbolRefAttr>(op.getConstituents()[0]);
100+
auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
109101
if (!splatConstituent)
110102
return rewriter.notifyMatchFailure(
111103
op, "expected flat symbol reference for splat constituent");

0 commit comments

Comments
 (0)