Skip to content

Commit 2a86b6e

Browse files
committed
Fix switch to if
1 parent 78b2b33 commit 2a86b6e

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/enzyme_ad/jax/Passes/CanonicalizeLoops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ class SwitchToIf final : public OpRewritePattern<scf::IndexSwitchOp> {
546546
cmpResult, /*withElseRegion=*/true);
547547

548548
// Move the first case block into the then region
549-
Block &firstBlock = switchOp.getCaseBlock(cases.front());
549+
Block &firstBlock = switchOp.getCaseBlock(0);
550550
rewriter.mergeBlocks(&firstBlock, ifOp.thenBlock(),
551551
firstBlock.getArguments());
552552

test/lit_tests/switch_to_if.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,30 @@ func.func @switch_to_if(%arg0: index) -> i32 {
2424
return %0 : i32
2525
}
2626

27+
// CHECK-LABEL: func @switch_to_if2
28+
func.func @switch_to_if2(%arg0: index) -> i32 {
29+
// CHECK-DAG: %[[CONST:.*]] = arith.constant 20 : index
30+
// CHECK-DAG: %[[VAL1:.*]] = arith.constant 42 : i32
31+
// CHECK-DAG: %[[VAL2:.*]] = arith.constant 24 : i32
32+
// CHECK: %[[CMP:.*]] = arith.cmpi eq, %arg0, %[[CONST]] : index
33+
// CHECK: %[[RESULT:.*]] = scf.if %[[CMP]] -> (i32) {
34+
// CHECK: scf.yield %[[VAL1]] : i32
35+
// CHECK: } else {
36+
// CHECK: scf.yield %[[VAL2]] : i32
37+
// CHECK: }
38+
// CHECK: return %[[RESULT]] : i32
39+
%0 = scf.index_switch %arg0 -> i32
40+
case 20 {
41+
%1 = arith.constant 42 : i32
42+
scf.yield %1 : i32
43+
}
44+
default {
45+
%1 = arith.constant 24 : i32
46+
scf.yield %1 : i32
47+
}
48+
return %0 : i32
49+
}
50+
2751
// Should not convert switches with more than 1 case
2852
// CHECK-LABEL: func @switch_two_cases
2953
func.func @switch_two_cases(%arg0: index) -> i32 {

0 commit comments

Comments
 (0)