Skip to content

Commit 6b92cce

Browse files
authored
Don't try to fold case ops with unsupported types (#2856)
The case-op folder pattern currently supports only int and float types (and tensors thereof). We plan to extend this in future, but for now, fail to match if the op's return type is unsupported.
1 parent c00a02e commit 6b92cce

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,20 @@ class InlineCaseOpWithConstantBranchIndex
564564
Operation* terminator = blockToInline->getTerminator();
565565
ValueRange results = terminator->getOperands();
566566

567+
// TODO: Add support for complex, quantized, and token return types.
568+
// Currently, this pattern only supports int and float return types. We'll
569+
// need a more general equivalent of `getZeroAttr` to support other types.
570+
SmallVector<TypedAttr> placeholderAttrs;
571+
for (auto result : op.getResults()) {
572+
TypedAttr placeholderAttr = rewriter.getZeroAttr(result.getType());
573+
if (!placeholderAttr)
574+
return rewriter.notifyMatchFailure(
575+
op,
576+
"The case op's return type isn't currently supported by this "
577+
"optimization pattern.");
578+
placeholderAttrs.push_back(placeholderAttr);
579+
}
580+
567581
// Inline the active branch of the `case` op.
568582
rewriter.inlineBlockBefore(blockToInline, op, blockArgs);
569583
rewriter.replaceAllOpUsesWith(op, results);
@@ -576,9 +590,9 @@ class InlineCaseOpWithConstantBranchIndex
576590
Block& noopBlock = region.emplaceBlock();
577591
SmallVector<Value> placeholderResults;
578592
rewriter.setInsertionPointToEnd(&noopBlock);
579-
for (auto result : op.getResults()) {
580-
placeholderResults.push_back(rewriter.create<ConstantOp>(
581-
region.getLoc(), rewriter.getZeroAttr(result.getType())));
593+
for (auto placeholderAttr : placeholderAttrs) {
594+
placeholderResults.push_back(
595+
rewriter.create<ConstantOp>(region.getLoc(), placeholderAttr));
582596
}
583597
rewriter.create<stablehlo::ReturnOp>(region.getLoc(), placeholderResults);
584598

0 commit comments

Comments
 (0)