Skip to content

Commit ae1eb18

Browse files
committed
Addressing further code review comments
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent 568bccb commit ae1eb18

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

mlir/lib/Dialect/SPIRV/Transforms/ConvertToReplicatedConstantCompositePass.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,24 @@ namespace {
2525
static std::pair<Attribute, uint32_t>
2626
getSplatAttrAndNumElements(Attribute valueAttr) {
2727
Attribute attr;
28-
uint32_t splatCount = 0;
28+
uint32_t numElements = 0;
2929
if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
3030
return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
3131
}
3232
if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
3333
if (llvm::all_equal(arrayAttr)) {
3434
attr = arrayAttr[0];
35-
splatCount = arrayAttr.size();
36-
}
35+
numElements = arrayAttr.size();
3736

38-
if (attr) {
3937
// Find the inner-most splat value for array of composites
40-
std::pair<Attribute, uint32_t> newSplatAttrAndCount =
41-
getSplatAttrAndNumElements(attr);
42-
if (newSplatAttrAndCount.first) {
43-
return newSplatAttrAndCount;
38+
auto [newAttr, newNumElements] = getSplatAttrAndNumElements(attr);
39+
if (newAttr) {
40+
return {newAttr, numElements * newNumElements};
4441
}
4542
}
4643
}
4744

48-
return {attr, splatCount};
45+
return {attr, numElements};
4946
}
5047

5148
struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
@@ -57,16 +54,16 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
5754
if (!compositeType)
5855
return rewriter.notifyMatchFailure(op, "not a composite constant");
5956

60-
auto [splatAttr, splatCount] = getSplatAttrAndNumElements(op.getValue());
61-
if (!splatAttr)
57+
auto [attr, numElements] = getSplatAttrAndNumElements(op.getValue());
58+
if (!attr)
6259
return rewriter.notifyMatchFailure(op, "composite is not splat");
6360

64-
if (splatCount == 1)
61+
if (numElements == 1)
6562
return rewriter.notifyMatchFailure(op,
6663
"composite has only one constituent");
6764

6865
rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
69-
op, op.getType(), splatAttr);
66+
op, op.getType(), attr);
7067
return success();
7168
}
7269
};
@@ -86,7 +83,7 @@ struct SpecConstantCompositeOpConversion final
8683
return rewriter.notifyMatchFailure(op,
8784
"composite has only one consituent");
8885

89-
if (!(llvm::all_equal(constituents)))
86+
if (!llvm::all_equal(constituents))
9087
return rewriter.notifyMatchFailure(op, "composite is not splat");
9188

9289
auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);

0 commit comments

Comments
 (0)