@@ -62,40 +62,49 @@ func.func @case_fold_constant_branch_index(%arg0: tensor<i32>, %arg1: tensor<i32
6262// -----
6363
6464// CHECK-LABEL: func.func @case_fold_preserve_side_effects
65- func.func @case_fold_preserve_side_effects (%arg0: tensor <i32 >, %arg1: tensor <i32 >, %arg2: tensor <i32 >) -> tensor <i32 > {
65+ func.func @case_fold_preserve_side_effects (%arg0: tensor <f32 >, %arg1: tensor <f32 >, %arg2: tensor <f32 >) -> ( tensor <3 x 2 x i32 >, tensor < f32 >) {
6666 // COM: // Inline the executed branch of the `case` op:
67- // CHECK-DAG: [[RESULT:%.+]] = stablehlo.custom_call @bar(%arg1) {has_side_effect = true}
67+ // CHECK-DAG: [[BAR:%.+]] = stablehlo.custom_call @bar(%arg1) {has_side_effect = true}
68+
69+ // COM: // Replace the inlined branch with a trivial placeholder that
70+ // COM: // just returns arbitrary constants matching the original type
71+ // COM: // signature.
72+ // CHECK-DAG: [[PLACEHOLDER_INT_TENSOR:%.+]] = stablehlo.constant dense<0> : tensor<3x2xi32>
73+ // CHECK-DAG: [[PLACEHOLDER_FLOAT:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
6874
6975 // COM: // Keep the rest of the `case` op if it's non-trivially dead:
70- // CHECK-DAG: [[BRANCH_INDEX:%.+]] = stablehlo.constant
76+ // CHECK-DAG: [[BRANCH_INDEX:%.+]] = stablehlo.constant dense<1> : tensor<i32>
77+
7178 // CHECK-DAG: [[NON_TRIVIALLY_DEAD_CASE_OP:%.+]] = {{"?}}stablehlo.case{{"?}}([[BRANCH_INDEX]]) ({
7279 // COM: // Non-trivially dead branches are preserved but unused.
7380 // CHECK-DAG: [[FOO:%.+]] = stablehlo.custom_call @foo(%arg0) {has_side_effect = true}
74- // CHECK-DAG: stablehlo.return [[FOO]]
81+ // CHECK-DAG: stablehlo.return [[FOO]], %arg0
7582 // COM: }, {
7683 // COM: // The executed branch is now just a trivial placeholder; its
77- // COM: // original logic has been inlined outside of the `case` op.
78- // CHECK-DAG: stablehlo.return [[BRANCH_INDEX]]
84+ // COM: // original logic has been inlined outside of the `case` op,
85+ // COM: // and the replacement logic just returns arbitrary constants
86+ // COM: // matching the original type signature.
87+ // CHECK-DAG: stablehlo.return [[PLACEHOLDER_INT_TENSOR]], [[PLACEHOLDER_FLOAT]]
7988 // COM: }, {
8089 // COM: // Non-trivially dead branches are preserved but unused.
8190 // CHECK-DAG: [[BAZ:%.+]] = stablehlo.custom_call @baz(%arg2) {has_side_effect = true}
82- // CHECK-DAG: stablehlo.return [[BAZ]]
91+ // CHECK-DAG: stablehlo.return [[BAZ]], %arg2
8392 // COM: })
8493
8594 // COM: // Return the result of the inlined branch.
86- // CHECK-DAG: {{(^ *|func\.)}}return [[RESULT]]
95+ // CHECK-DAG: {{(^ *|func\.)}}return [[BAR]], %arg1
8796 %branch_index = stablehlo.constant dense <1 > : tensor <i32 >
88- %result = " stablehlo.case" (%branch_index ) ({
89- %foo = stablehlo.custom_call @foo (%arg0 ) {has_side_effect = true } : (tensor <i32 >) -> tensor <i32 >
90- stablehlo.return %foo : tensor <i32 >
97+ %result:2 = " stablehlo.case" (%branch_index ) ({
98+ %foo = stablehlo.custom_call @foo (%arg0 ) {has_side_effect = true } : (tensor <f32 >) -> tensor <3 x 2 x i32 >
99+ stablehlo.return %foo , %arg0 : tensor <3 x 2 x i32 >, tensor < f32 >
91100 }, {
92- %bar = stablehlo.custom_call @bar (%arg1 ) {has_side_effect = true } : (tensor <i32 >) -> tensor <i32 >
93- stablehlo.return %bar : tensor <i32 >
101+ %bar = stablehlo.custom_call @bar (%arg1 ) {has_side_effect = true } : (tensor <f32 >) -> tensor <3 x 2 x i32 >
102+ stablehlo.return %bar , %arg1 : tensor <3 x 2 x i32 >, tensor < f32 >
94103 }, {
95- %baz = stablehlo.custom_call @baz (%arg2 ) {has_side_effect = true } : (tensor <i32 >) -> tensor <i32 >
96- stablehlo.return %baz : tensor <i32 >
97- }) : (tensor <i32 >) -> tensor <i32 >
98- func.return %result: tensor <i32 >
104+ %baz = stablehlo.custom_call @baz (%arg2 ) {has_side_effect = true } : (tensor <f32 >) -> tensor <3 x 2 x i32 >
105+ stablehlo.return %baz , %arg2 : tensor <3 x 2 x i32 >, tensor < f32 >
106+ }) : (tensor <i32 >) -> ( tensor <3 x 2 x i32 >, tensor < f32 >)
107+ func.return %result#0 , %result#1 : tensor <3 x 2 x i32 >, tensor < f32 >
99108}
100109
101110// -----
0 commit comments