Commit 4d20289

Optimize folding transposes that operate on splats (#2863)
Instead of expanding and iterating over the splat to create the new constant, we now just build a new splat constant with the result type's dimensions, as specified by the transpose op. When tested on the same input program that brought this issue to light, this fix improved the optimizer's execution time from 8.68 s to 0.80 s, a 985% speedup (about 10.9x faster).
1 parent ef07ca9 commit 4d20289

1 file changed: +4 -2 lines changed

stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp

Lines changed: 4 additions & 2 deletions
@@ -1549,9 +1549,11 @@ struct FoldTransposeOpPattern : public FoldOpRewritePattern<TransposeOp> {
       return rewriter.notifyMatchFailure(
           op, "expected constant integer or float operand");
 
-    // TODO: Does this expand splat values? Should we special case splats?
     DenseElementsAttr resAttr;
-    if (auto data = els.tryGetValues<APInt>())
+    if (auto splat = dyn_cast<SplatElementsAttr>(els))
+      resAttr =
+          DenseElementsAttr::get(resultType, splat.getSplatValue<Attribute>());
+    else if (auto data = els.tryGetValues<APInt>())
       resAttr = transposeType(op, *data);
     else if (auto data = els.tryGetValues<APFloat>())
       resAttr = transposeType(op, *data);
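
The idea behind the splat shortcut can be sketched outside of MLIR. The following is a minimal, self-contained toy model, not the StableHLO implementation: `ToyConstant`, `permuteShape`, and `foldTranspose` are hypothetical names introduced only for illustration, and the element type is fixed to `int64_t`. It assumes row-major storage and the convention that result dimension `i` takes its size from operand dimension `perm[i]`. The point it shows is that a splat needs only its shape permuted, while a general constant needs one index computation per element.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a constant tensor; this is NOT the MLIR API.
struct ToyConstant {
  std::vector<int64_t> shape;
  std::vector<int64_t> values;  // exactly one element when isSplat is true,
                                // otherwise the full row-major data
  bool isSplat;
};

// Result dimension i takes its size from operand dimension perm[i].
std::vector<int64_t> permuteShape(const std::vector<int64_t>& shape,
                                  const std::vector<size_t>& perm) {
  assert(perm.size() == shape.size());
  std::vector<int64_t> result(shape.size());
  for (size_t i = 0; i < perm.size(); ++i) result[i] = shape[perm[i]];
  return result;
}

ToyConstant foldTranspose(const ToyConstant& in,
                          const std::vector<size_t>& perm) {
  std::vector<int64_t> outShape = permuteShape(in.shape, perm);

  // Fast path: every element is identical, so only the shape changes.
  // The single stored value is reused as-is; nothing is expanded.
  if (in.isSplat) return {outShape, in.values, true};

  // General path: compute the source offset of every output element.
  size_t rank = in.shape.size();
  std::vector<int64_t> inStrides(rank, 1);
  for (size_t d = rank; d-- > 1;)
    inStrides[d - 1] = inStrides[d] * in.shape[d];

  int64_t total = 1;
  for (int64_t dim : outShape) total *= dim;

  std::vector<int64_t> out(static_cast<size_t>(total));
  for (int64_t linear = 0; linear < total; ++linear) {
    int64_t rem = linear;
    int64_t src = 0;
    // Peel off output indices from the innermost dimension outwards and
    // map each one onto the operand dimension it came from.
    for (size_t d = rank; d-- > 0;) {
      int64_t idx = rem % outShape[d];
      rem /= outShape[d];
      src += idx * inStrides[perm[d]];
    }
    out[static_cast<size_t>(linear)] = in.values[static_cast<size_t>(src)];
  }
  return {outShape, out, false};
}
```

In this sketch, transposing a 1000x2000 splat touches a single stored value regardless of the element count, whereas the general path performs work proportional to the number of elements. That difference is the kind of cost the commit message describes avoiding.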
