Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,24 @@ namespace {
static std::pair<Attribute, uint32_t>
getSplatAttrAndNumElements(Attribute valueAttr) {
Attribute attr;
uint32_t splatCount = 0;
uint32_t numElements = 0;
if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
}
if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
if (llvm::all_equal(arrayAttr)) {
attr = arrayAttr[0];
splatCount = arrayAttr.size();
}
numElements = arrayAttr.size();

if (attr) {
// Find the inner-most splat value for array of composites
std::pair<Attribute, uint32_t> newSplatAttrAndCount =
getSplatAttrAndNumElements(attr);
if (newSplatAttrAndCount.first) {
return newSplatAttrAndCount;
auto [newAttr, newNumElements] = getSplatAttrAndNumElements(attr);
if (newAttr) {
return {newAttr, numElements * newNumElements};
}
}
}

return {attr, splatCount};
return {attr, numElements};
}

struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
Expand All @@ -57,16 +54,16 @@ struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
if (!compositeType)
return rewriter.notifyMatchFailure(op, "not a composite constant");

auto [splatAttr, splatCount] = getSplatAttrAndNumElements(op.getValue());
if (!splatAttr)
auto [attr, numElements] = getSplatAttrAndNumElements(op.getValue());
if (!attr)
return rewriter.notifyMatchFailure(op, "composite is not splat");

if (splatCount == 1)
if (numElements == 1)
return rewriter.notifyMatchFailure(op,
"composite has only one constituent");

rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
op, op.getType(), splatAttr);
op, op.getType(), attr);
return success();
}
};
Expand All @@ -86,7 +83,7 @@ struct SpecConstantCompositeOpConversion final
return rewriter.notifyMatchFailure(op,
"composite has only one consituent");

if (!(llvm::all_equal(constituents)))
if (!llvm::all_equal(constituents))
return rewriter.notifyMatchFailure(op, "composite is not splat");

auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
Expand Down
Loading