Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,

// Convert the case regions.
SmallVector<Block *> caseSuccessors;
SmallVector<int32_t> caseValues;
SmallVector<int64_t> caseValues;
caseSuccessors.reserve(op.getCases().size());
caseValues.reserve(op.getCases().size());
for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
Expand All @@ -691,11 +691,14 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,

// Cast switch index to integer case value.
Value caseValue = rewriter.create<arith::IndexCastOp>(
op.getLoc(), rewriter.getI32Type(), op.getArg());
op.getLoc(), rewriter.getI64Type(), op.getArg());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised to see an IndexCast here. Isn't there a different mechanism in MLIR to specify how IndexType is lowered?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to convert index values at all? Can we defer this to the point where we lower to LLVM?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My take on this is that LLVMConversionTarget might not be the best way to resolve this. The index type is converted to integer types based on the data layout (or overridden by hand), as you suggested. On a 32-bit machine, e.g. RV32, index type is typically converted to i32. This still causes the overflow problem since we need i64 for the switch argument. If the conversion bitwidth is overridden to i64, then it will affect other dialect conversion, such as memref, since memref.extract_aligned_pointer_as_index returns an index type. Having an i64 casted to a llvm.ptr seems awkward on a 32-bit machine.

An alternative is to lower the LLVM in multiple steps with different conversion targets, but I am not sure if this approach is user friendly. Or am I missing something here?


ShapedType caseValueType = VectorType::get(
static_cast<int64_t>(caseValues.size()), rewriter.getI64Type());
rewriter.create<cf::SwitchOp>(
op.getLoc(), caseValue, *defaultBlock, ValueRange(),
rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
DenseIntElementsAttr::get(caseValueType, caseValues), caseSuccessors,
caseOperands);
rewriter.replaceOp(op, continueBlock->getArguments());
return success();
}
Expand Down
21 changes: 19 additions & 2 deletions mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -622,8 +622,8 @@ func.func @func_execute_region_elim_multi_yield() {

// CHECK-LABEL: @index_switch
func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
// CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i32
// CHECK: cf.switch %[[CASE]] : i32
// CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64
// CHECK: cf.switch %[[CASE]] : i64
// CHECK-NEXT: default: ^[[DEFAULT:.+]],
// CHECK-NEXT: 0: ^[[bb1:.+]],
// CHECK-NEXT: 1: ^[[bb2:.+]]
Expand All @@ -648,6 +648,23 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
return %0 : i32
}

// CHECK-LABEL: @index_switch_large_case
func.func @index_switch_large_case(%i : index) {
// CHECK: cf.switch
// CHECK: 4294967296: ^[[bb1:.+]]
scf.index_switch %i
case 4294967296 { // 2^32
// CHECK: ^[[bb1]]:
// CHECK-NEXT: "test.op"
"test.op"() : () -> ()
scf.yield
}
default {
scf.yield
}
return
}

// Note: scf.forall is lowered to scf.parallel, which is currently lowered to
// scf.for and then to unstructured control flow. scf.parallel could lower more
// efficiently to multi-threaded IR, at which point scf.forall would
Expand Down
Loading