Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions lib/Transforms/IndexSwitchToIf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace circt {
#define GEN_PASS_DEF_INDEXSWITCHTOIF
Expand All @@ -38,21 +39,23 @@ struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
Region &defaultRegion = switchOp.getDefaultRegion();
bool hasResults = !switchOp.getResultTypes().empty();

Value finalResult;
SmallVector<Value> finalResults;
scf::IfOp prevIfOp = nullptr;

rewriter.setInsertionPointAfter(switchOp);
auto switchCases = switchOp.getCases();
Value switchOperand = adaptor.getArg();
if (!switchOperand)
switchOperand = switchOp.getOperand();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If getArg fails I would have thought this should be a failure? I.e. return failure(); as then the operator is not satisfying its definition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I addressed that the simplest possible way @cowardsa

for (size_t i = 0; i < switchCases.size(); i++) {
auto caseValueInt = switchCases[i];
if (prevIfOp)
rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front());

Value caseValue =
arith::ConstantIndexOp::create(rewriter, loc, caseValueInt);
Value cond =
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
switchOp.getOperand(), caseValue);
Value cond = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::eq, switchOperand, caseValue);

auto ifOp = scf::IfOp::create(rewriter, loc, switchOp.getResultTypes(),
cond, /*hasElseRegion=*/true);
Expand All @@ -70,17 +73,17 @@ struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {

if (prevIfOp && hasResults) {
rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
scf::YieldOp::create(rewriter, loc, ifOp.getResult(0));
scf::YieldOp::create(rewriter, loc, ifOp.getResults());
}

if (i == 0 && hasResults)
finalResult = ifOp.getResult(0);
llvm::append_range(finalResults, ifOp.getResults());

prevIfOp = ifOp;
}

if (hasResults)
rewriter.replaceOp(switchOp, finalResult);
rewriter.replaceOp(switchOp, finalResults);
else
rewriter.eraseOp(switchOp);

Expand Down
51 changes: 51 additions & 0 deletions test/Transforms/switch-to-if.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,57 @@ module {
}
}

// Switches that yield multiple values should thread every result through the
// nested if-else chain.

// -----

// CHECK-LABEL: func.func @multi_result(
// CHECK-SAME: %[[ARG:.*]]: index) -> (i32, f32) {
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK: %[[CMP0:.*]] = arith.cmpi eq, %[[ARG]], %[[ZERO]] : index
// CHECK: %[[IF0:[0-9]+]]:2 = scf.if %[[CMP0]] -> (i32, f32) {
// CHECK: %[[C0:.*]] = arith.constant 10 : i32
// CHECK: %[[F0:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: scf.yield %[[C0]], %[[F0]] : i32, f32
// CHECK: } else {
// CHECK: %[[ONE:.*]] = arith.constant 1 : index
// CHECK: %[[CMP1:.*]] = arith.cmpi eq, %[[ARG]], %[[ONE]] : index
// CHECK: %[[IF1:[0-9]+]]:2 = scf.if %[[CMP1]] -> (i32, f32) {
// CHECK: %[[C1:.*]] = arith.constant 20 : i32
// CHECK: %[[F1:.*]] = arith.constant 2.000000e+00 : f32
// CHECK: scf.yield %[[C1]], %[[F1]] : i32, f32
// CHECK: } else {
// CHECK: %[[CDEF:.*]] = arith.constant 30 : i32
// CHECK: %[[FDEF:.*]] = arith.constant 3.000000e+00 : f32
// CHECK: scf.yield %[[CDEF]], %[[FDEF]] : i32, f32
Comment on lines +80 to +82
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amusing observation - C default -> CDEF (just thought you'd written a string of letters)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realizing that, Constant of the default case in my mind..

// CHECK: }
// CHECK: scf.yield %[[IF1]]#0, %[[IF1]]#1 : i32, f32
// CHECK: }
// CHECK: return %[[IF0]]#0, %[[IF0]]#1 : i32, f32
// CHECK: }
module {
func.func @multi_result(%arg0 : index) -> (i32, f32) {
%0, %1 = scf.index_switch %arg0 -> i32, f32
case 0 {
%c0 = arith.constant 10 : i32
%f0 = arith.constant 1.0 : f32
scf.yield %c0, %f0 : i32, f32
}
case 1 {
%c1 = arith.constant 20 : i32
%f1 = arith.constant 2.0 : f32
scf.yield %c1, %f1 : i32, f32
}
default {
%cd = arith.constant 30 : i32
%fd = arith.constant 3.0 : f32
scf.yield %cd, %fd : i32, f32
}
return %0, %1 : i32, f32
}
}

// Switch to nested if-else when the yielded result is empty

// -----
Expand Down