@@ -25,27 +25,24 @@ namespace {
2525static std::pair<Attribute, uint32_t >
2626getSplatAttrAndNumElements (Attribute valueAttr) {
2727 Attribute attr;
28- uint32_t splatCount = 0 ;
28+ uint32_t numElements = 0 ;
2929 if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
3030 return {splatAttr.getSplatValue <Attribute>(), splatAttr.size ()};
3131 }
3232 if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
3333 if (llvm::all_equal (arrayAttr)) {
3434 attr = arrayAttr[0 ];
35- splatCount = arrayAttr.size ();
36- }
35+ numElements = arrayAttr.size ();
3736
38- if (attr) {
3937 // Find the inner-most splat value for array of composites
40- std::pair<Attribute, uint32_t > newSplatAttrAndCount =
41- getSplatAttrAndNumElements (attr);
42- if (newSplatAttrAndCount.first ) {
43- return newSplatAttrAndCount;
38+ auto [newAttr, newNumElements] = getSplatAttrAndNumElements (attr);
39+ if (newAttr) {
40+ return {newAttr, numElements * newNumElements};
4441 }
4542 }
4643 }
4744
48- return {attr, splatCount };
45+ return {attr, numElements };
4946}
5047
5148struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
@@ -57,16 +54,16 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
5754 if (!compositeType)
5855 return rewriter.notifyMatchFailure (op, " not a composite constant" );
5956
60- auto [splatAttr, splatCount ] = getSplatAttrAndNumElements (op.getValue ());
61- if (!splatAttr )
57+ auto [attr, numElements ] = getSplatAttrAndNumElements (op.getValue ());
58+ if (!attr )
6259 return rewriter.notifyMatchFailure (op, " composite is not splat" );
6360
64- if (splatCount == 1 )
61+ if (numElements == 1 )
6562 return rewriter.notifyMatchFailure (op,
6663 " composite has only one constituent" );
6764
6865 rewriter.replaceOpWithNewOp <spirv::EXTConstantCompositeReplicateOp>(
69- op, op.getType (), splatAttr );
66+ op, op.getType (), attr );
7067 return success ();
7168 }
7269};
@@ -86,7 +83,7 @@ struct SpecConstantCompositeOpConversion final
8683 return rewriter.notifyMatchFailure (op,
8784 " composite has only one consituent" );
8885
89- if (!( llvm::all_equal (constituents) ))
86+ if (!llvm::all_equal (constituents))
9087 return rewriter.notifyMatchFailure (op, " composite is not splat" );
9188
9289 auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0 ]);
0 commit comments