diff --git a/lib/Transforms/IndexSwitchToIf.cpp b/lib/Transforms/IndexSwitchToIf.cpp index 359ef0699e1f..0d328127b0c6 100644 --- a/lib/Transforms/IndexSwitchToIf.cpp +++ b/lib/Transforms/IndexSwitchToIf.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/STLExtras.h" namespace circt { #define GEN_PASS_DEF_INDEXSWITCHTOIF @@ -38,11 +39,15 @@ struct SwitchToIfConversion : public OpConversionPattern { Region &defaultRegion = switchOp.getDefaultRegion(); bool hasResults = !switchOp.getResultTypes().empty(); - Value finalResult; + SmallVector finalResults; scf::IfOp prevIfOp = nullptr; rewriter.setInsertionPointAfter(switchOp); auto switchCases = switchOp.getCases(); + Value switchOperand = adaptor.getArg(); + if (!switchOperand) + return rewriter.notifyMatchFailure(switchOp, + "missing converted switch operand"); for (size_t i = 0; i < switchCases.size(); i++) { auto caseValueInt = switchCases[i]; if (prevIfOp) @@ -50,9 +55,8 @@ struct SwitchToIfConversion : public OpConversionPattern { Value caseValue = arith::ConstantIndexOp::create(rewriter, loc, caseValueInt); - Value cond = - arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, - switchOp.getOperand(), caseValue); + Value cond = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, switchOperand, caseValue); auto ifOp = scf::IfOp::create(rewriter, loc, switchOp.getResultTypes(), cond, /*hasElseRegion=*/true); @@ -70,17 +74,17 @@ struct SwitchToIfConversion : public OpConversionPattern { if (prevIfOp && hasResults) { rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front()); - scf::YieldOp::create(rewriter, loc, ifOp.getResult(0)); + scf::YieldOp::create(rewriter, loc, ifOp.getResults()); } if (i == 0 && hasResults) - finalResult = ifOp.getResult(0); + llvm::append_range(finalResults, ifOp.getResults()); prevIfOp = ifOp; } if (hasResults) - rewriter.replaceOp(switchOp, finalResult); + rewriter.replaceOp(switchOp, finalResults); else rewriter.eraseOp(switchOp); diff --git a/test/Transforms/switch-to-if.mlir b/test/Transforms/switch-to-if.mlir index 579a720172f1..98feedfbbd90 100644 --- a/test/Transforms/switch-to-if.mlir +++ b/test/Transforms/switch-to-if.mlir @@ -56,6 +56,57 @@ module { } } +// Switches that yield multiple values should thread every result through the +// nested if-else chain. + +// ----- + +// CHECK-LABEL: func.func @multi_result( +// CHECK-SAME: %[[ARG:.*]]: index) -> (i32, f32) { +// CHECK: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK: %[[CMP0:.*]] = arith.cmpi eq, %[[ARG]], %[[ZERO]] : index +// CHECK: %[[IF0:[0-9]+]]:2 = scf.if %[[CMP0]] -> (i32, f32) { +// CHECK: %[[C0:.*]] = arith.constant 10 : i32 +// CHECK: %[[F0:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: scf.yield %[[C0]], %[[F0]] : i32, f32 +// CHECK: } else { +// CHECK: %[[ONE:.*]] = arith.constant 1 : index +// CHECK: %[[CMP1:.*]] = arith.cmpi eq, %[[ARG]], %[[ONE]] : index +// CHECK: %[[IF1:[0-9]+]]:2 = scf.if %[[CMP1]] -> (i32, f32) { +// CHECK: %[[C1:.*]] = arith.constant 20 : i32 +// CHECK: %[[F1:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: scf.yield %[[C1]], %[[F1]] : i32, f32 +// CHECK: } else { +// CHECK: %[[CDEF:.*]] = arith.constant 30 : i32 +// CHECK: %[[FDEF:.*]] = arith.constant 3.000000e+00 : f32 +// CHECK: scf.yield %[[CDEF]], %[[FDEF]] : i32, f32 +// CHECK: } +// CHECK: scf.yield %[[IF1]]#0, %[[IF1]]#1 : i32, f32 +// CHECK: } +// CHECK: return %[[IF0]]#0, %[[IF0]]#1 : i32, f32 +// CHECK: } +module { + func.func @multi_result(%arg0 : index) -> (i32, f32) { + %0, %1 = scf.index_switch %arg0 -> i32, f32 + case 0 { + %c0 = arith.constant 10 : i32 + %f0 = arith.constant 1.0 : f32 + scf.yield %c0, %f0 : i32, f32 + } + case 1 { + %c1 = arith.constant 20 : i32 + %f1 = arith.constant 2.0 : f32 + scf.yield %c1, %f1 : i32, f32 + } + default { + %cd = arith.constant 30 : i32 + %fd = arith.constant 3.0 : f32 + scf.yield %cd, %fd : i32, f32 + } + return %0, %1 : i32, f32 + } +} + // Switch to nested if-else when the yielded result is empty // -----