1- // ===- ConvertToReplicatedConstantCompositePass.cpp --------------------===//
1+ // ===- ConvertToReplicatedConstantCompositePass.cpp ----------------------- ===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
@@ -40,15 +40,9 @@ getSplatAttributeAndCount(Attribute valueAttr) {
4040 }
4141
4242 if (attr) {
43- if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
44- if (isa<spirv::CompositeType>(typedAttr.getType ())) {
45- std::pair<Attribute, uint32_t > newSplatAttrAndCount =
46- getSplatAttributeAndCount (attr);
47- if (newSplatAttrAndCount.first ) {
48- return newSplatAttrAndCount;
49- }
50- }
51- } else if (isa<ArrayAttr>(attr)) {
43+ auto typedAttr = dyn_cast<TypedAttr>(attr);
44+ if ((typedAttr && isa<spirv::CompositeType>(typedAttr.getType ())) ||
45+ isa<ArrayAttr>(attr)) {
5246 std::pair<Attribute, uint32_t > newSplatAttrAndCount =
5347 getSplatAttributeAndCount (attr);
5448 if (newSplatAttrAndCount.first ) {
@@ -69,17 +63,16 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
6963 if (!compositeType)
7064 return rewriter.notifyMatchFailure (op, " not a composite constant" );
7165
72- std::pair<Attribute, uint32_t > splatAttrAndCount =
73- getSplatAttributeAndCount (op.getValue ());
74- if (!splatAttrAndCount.first )
66+ auto [splattAttr, splatCount] = getSplatAttributeAndCount (op.getValue ());
67+ if (!splattAttr)
7568 return rewriter.notifyMatchFailure (op, " composite is not splat" );
7669
77- if (splatAttrAndCount. second == 1 )
70+ if (splatCount == 1 )
7871 return rewriter.notifyMatchFailure (op,
7972 " composite has only one constituent" );
8073
8174 rewriter.replaceOpWithNewOp <spirv::EXTConstantCompositeReplicateOp>(
82- op, op.getType (), splatAttrAndCount. first );
75+ op, op.getType (), splattAttr );
8376
8477 return success ();
8578 }
@@ -104,8 +97,7 @@ struct SpecConstantCompositeOpConversion final
10497 std::not_equal_to<>()) == constituents.end ()))
10598 return rewriter.notifyMatchFailure (op, " composite is not splat" );
10699
107- auto splatConstituent =
108- dyn_cast<FlatSymbolRefAttr>(op.getConstituents ()[0 ]);
100+ auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0 ]);
109101 if (!splatConstituent)
110102 return rewriter.notifyMatchFailure (
111103 op, " expected flat symbol reference for splat constituent" );
0 commit comments