Skip to content

Commit fb3d7b1

Browse files
authored
Fix a return-type bug in some folded case branches (#2849)
When folding case ops with side effects, the prior logic could sometimes leave branches with mismatched return types. We now explicitly create constants matching the expected return types in order to ensure they always match.
1 parent 9e31595 commit fb3d7b1

File tree

2 files changed

+37
-23
lines changed

2 files changed

+37
-23
lines changed

stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,40 +62,49 @@ func.func @case_fold_constant_branch_index(%arg0: tensor<i32>, %arg1: tensor<i32
6262
// -----
6363

6464
// CHECK-LABEL: func.func @case_fold_preserve_side_effects
65-
func.func @case_fold_preserve_side_effects(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
65+
func.func @case_fold_preserve_side_effects(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> (tensor<3x2xi32>, tensor<f32>) {
6666
// COM: // Inline the executed branch of the `case` op:
67-
// CHECK-DAG: [[RESULT:%.+]] = stablehlo.custom_call @bar(%arg1) {has_side_effect = true}
67+
// CHECK-DAG: [[BAR:%.+]] = stablehlo.custom_call @bar(%arg1) {has_side_effect = true}
68+
69+
// COM: // Replace the inlined branch with a trivial placeholder that
70+
// COM: // just returns arbitrary constants matching the original type
71+
// COM: // signature.
72+
// CHECK-DAG: [[PLACEHOLDER_INT_TENSOR:%.+]] = stablehlo.constant dense<0> : tensor<3x2xi32>
73+
// CHECK-DAG: [[PLACEHOLDER_FLOAT:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
6874

6975
// COM: // Keep the rest of the `case` op if it's non-trivially dead:
70-
// CHECK-DAG: [[BRANCH_INDEX:%.+]] = stablehlo.constant
76+
// CHECK-DAG: [[BRANCH_INDEX:%.+]] = stablehlo.constant dense<1> : tensor<i32>
77+
7178
// CHECK-DAG: [[NON_TRIVIALLY_DEAD_CASE_OP:%.+]] = {{"?}}stablehlo.case{{"?}}([[BRANCH_INDEX]]) ({
7279
// COM: // Non-trivially dead branches are preserved but unused.
7380
// CHECK-DAG: [[FOO:%.+]] = stablehlo.custom_call @foo(%arg0) {has_side_effect = true}
74-
// CHECK-DAG: stablehlo.return [[FOO]]
81+
// CHECK-DAG: stablehlo.return [[FOO]], %arg0
7582
// COM: }, {
7683
// COM: // The executed branch is now just a trivial placeholder; its
77-
// COM: // original logic has been inlined outside of the `case` op.
78-
// CHECK-DAG: stablehlo.return [[BRANCH_INDEX]]
84+
// COM: // original logic has been inlined outside of the `case` op,
85+
// COM: // and the replacement logic just returns arbitrary constants
86+
// COM: // matching the original type signature.
87+
// CHECK-DAG: stablehlo.return [[PLACEHOLDER_INT_TENSOR]], [[PLACEHOLDER_FLOAT]]
7988
// COM: }, {
8089
// COM: // Non-trivially dead branches are preserved but unused.
8190
// CHECK-DAG: [[BAZ:%.+]] = stablehlo.custom_call @baz(%arg2) {has_side_effect = true}
82-
// CHECK-DAG: stablehlo.return [[BAZ]]
91+
// CHECK-DAG: stablehlo.return [[BAZ]], %arg2
8392
// COM: })
8493

8594
// COM: // Return the result of the inlined branch.
86-
// CHECK-DAG: {{(^ *|func\.)}}return [[RESULT]]
95+
// CHECK-DAG: {{(^ *|func\.)}}return [[BAR]], %arg1
8796
%branch_index = stablehlo.constant dense<1> : tensor<i32>
88-
%result = "stablehlo.case"(%branch_index) ({
89-
%foo = stablehlo.custom_call @foo(%arg0) {has_side_effect = true} : (tensor<i32>) -> tensor<i32>
90-
stablehlo.return %foo : tensor<i32>
97+
%result:2 = "stablehlo.case"(%branch_index) ({
98+
%foo = stablehlo.custom_call @foo(%arg0) {has_side_effect = true} : (tensor<f32>) -> tensor<3x2xi32>
99+
stablehlo.return %foo, %arg0 : tensor<3x2xi32>, tensor<f32>
91100
}, {
92-
%bar = stablehlo.custom_call @bar(%arg1) {has_side_effect = true} : (tensor<i32>) -> tensor<i32>
93-
stablehlo.return %bar : tensor<i32>
101+
%bar = stablehlo.custom_call @bar(%arg1) {has_side_effect = true} : (tensor<f32>) -> tensor<3x2xi32>
102+
stablehlo.return %bar, %arg1 : tensor<3x2xi32>, tensor<f32>
94103
}, {
95-
%baz = stablehlo.custom_call @baz(%arg2) {has_side_effect = true} : (tensor<i32>) -> tensor<i32>
96-
stablehlo.return %baz : tensor<i32>
97-
}) : (tensor<i32>) -> tensor<i32>
98-
func.return %result: tensor<i32>
104+
%baz = stablehlo.custom_call @baz(%arg2) {has_side_effect = true} : (tensor<f32>) -> tensor<3x2xi32>
105+
stablehlo.return %baz, %arg2 : tensor<3x2xi32>, tensor<f32>
106+
}) : (tensor<i32>) -> (tensor<3x2xi32>, tensor<f32>)
107+
func.return %result#0, %result#1 : tensor<3x2xi32>, tensor<f32>
99108
}
100109

101110
// -----

stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -503,24 +503,29 @@ class InlineCaseOpWithConstantBranchIndex
503503
selectedBranchIndex = op.getNumRegions() - 1;
504504

505505
Region& region = op.getRegion(selectedBranchIndex);
506-
assert(llvm::hasSingleElement(region));
507-
Block* block = &region.front();
506+
assert(region.hasOneBlock());
507+
Block* blockToInline = &region.front();
508508
ValueRange blockArgs = {};
509-
Operation* terminator = block->getTerminator();
509+
Operation* terminator = blockToInline->getTerminator();
510510
ValueRange results = terminator->getOperands();
511511

512512
// Inline the active branch of the `case` op.
513-
rewriter.inlineBlockBefore(block, op, blockArgs);
513+
rewriter.inlineBlockBefore(blockToInline, op, blockArgs);
514514
rewriter.replaceAllOpUsesWith(op, results);
515515
rewriter.eraseOp(terminator);
516516

517517
// Make sure the now-dead `case` op is still syntactically valid in case it
518518
// can't be safely deleted (e.g. due to side effects). Specifically, we left
519519
// one region of the `case` op empty when we inlined that block; it expects
520-
// a block with a terminator op, so we just make it return the branch index.
520+
// a block with a terminator op, so we need to make it return something.
521521
Block& noopBlock = region.emplaceBlock();
522+
SmallVector<Value> placeholderResults;
522523
rewriter.setInsertionPointToEnd(&noopBlock);
523-
rewriter.create<stablehlo::ReturnOp>(region.getLoc(), branchIndexArgument);
524+
for (auto result : op.getResults()) {
525+
placeholderResults.push_back(rewriter.create<ConstantOp>(
526+
region.getLoc(), rewriter.getZeroAttr(result.getType())));
527+
}
528+
rewriter.create<stablehlo::ReturnOp>(region.getLoc(), placeholderResults);
524529

525530
return success();
526531
}

0 commit comments

Comments
 (0)