[MLIR] fix invalid scf.index_switch lowering to cf.switch when case values are large
#111590
base: main
@llvm/pr-subscribers-mlir

Author: Keyi Zhang (Kuree)

Full diff: https://github.com/llvm/llvm-project/pull/111590.diff

2 Files Affected:

- mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
- mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
```diff
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 45f3bcfa393be8..5b7b6713397048 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -669,7 +669,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
 
   // Convert the case regions.
   SmallVector<Block *> caseSuccessors;
-  SmallVector<int32_t> caseValues;
+  SmallVector<int64_t> caseValues;
   caseSuccessors.reserve(op.getCases().size());
   caseValues.reserve(op.getCases().size());
   for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
@@ -691,11 +691,14 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
 
   // Cast switch index to integer case value.
   Value caseValue = rewriter.create<arith::IndexCastOp>(
-      op.getLoc(), rewriter.getI32Type(), op.getArg());
+      op.getLoc(), rewriter.getI64Type(), op.getArg());
+  ShapedType caseValueType = VectorType::get(
+      static_cast<int64_t>(caseValues.size()), rewriter.getI64Type());
 
   rewriter.create<cf::SwitchOp>(
       op.getLoc(), caseValue, *defaultBlock, ValueRange(),
-      rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+      DenseIntElementsAttr::get(caseValueType, caseValues), caseSuccessors,
+      caseOperands);
   rewriter.replaceOp(op, continueBlock->getArguments());
   return success();
 }
```
```diff
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 9ea0093eff7868..ba841313320194 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -622,8 +622,8 @@ func.func @func_execute_region_elim_multi_yield() {
 
 // CHECK-LABEL: @index_switch
 func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
-  // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i32
-  // CHECK: cf.switch %[[CASE]] : i32
+  // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64
+  // CHECK: cf.switch %[[CASE]] : i64
   // CHECK-NEXT: default: ^[[DEFAULT:.+]],
   // CHECK-NEXT: 0: ^[[bb1:.+]],
   // CHECK-NEXT: 1: ^[[bb2:.+]]
@@ -648,6 +648,23 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: @index_switch_large_case
+func.func @index_switch_large_case(%i : index) {
+  // CHECK: cf.switch
+  // CHECK: 4294967296: ^[[bb1:.+]]
+  scf.index_switch %i
+  case 4294967296 { // 2^32
+    // CHECK: ^[[bb1]]:
+    // CHECK-NEXT: "test.op"
+    "test.op"() : () -> ()
+    scf.yield
+  }
+  default {
+    scf.yield
+  }
+  return
+}
+
 // Note: scf.forall is lowered to scf.parallel, which is currently lowered to
 // scf.for and then to unstructured control flow. scf.parallel could lower more
 // efficiently to multi-threaded IR, at which point scf.forall would
```
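To show what the old `int32_t` path does at runtime, here is a minimal C++ model of the lowered `cf.switch`. It is a sketch with these assumptions. `dispatch`, `dispatchI32`, and `dispatchI64` are hypothetical helpers, not MLIR APIs. A linear search stands in for the switch. Narrowing is modeled with `static_cast<int32_t>`, which wraps modulo 2^32 in C++20 and on mainstream compilers before that. This mirrors `arith.index_cast` truncating a 64-bit index to i32.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of a lowered cf.switch: returns the position of the first
// case value equal to `arg`, or std::nullopt for the default destination.
template <typename T>
std::optional<std::size_t> dispatch(const std::vector<T> &caseValues, T arg) {
  for (std::size_t i = 0; i < caseValues.size(); ++i)
    if (caseValues[i] == arg)
      return i;
  return std::nullopt;
}

// Old behavior (modeled): case values stored as int32_t and the switch
// argument truncated to 32 bits, as an index_cast from a 64-bit index to i32.
std::optional<std::size_t> dispatchI32(const std::vector<int64_t> &cases,
                                       int64_t index) {
  std::vector<int32_t> narrowed;
  narrowed.reserve(cases.size());
  for (int64_t v : cases)
    narrowed.push_back(static_cast<int32_t>(v)); // keeps only the low 32 bits
  return dispatch<int32_t>(narrowed, static_cast<int32_t>(index));
}

// New behavior (modeled): case values and switch argument both stay 64-bit.
std::optional<std::size_t> dispatchI64(const std::vector<int64_t> &cases,
                                       int64_t index) {
  return dispatch<int64_t>(cases, index);
}
```

Take a switch whose only case is 2^32, as in the new `@index_switch_large_case` test. In this model, `dispatchI32(cases, 0)` selects that case, while `dispatchI64(cases, 0)` falls through to the default.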
```diff
   // Cast switch index to integer case value.
   Value caseValue = rewriter.create<arith::IndexCastOp>(
-      op.getLoc(), rewriter.getI32Type(), op.getArg());
+      op.getLoc(), rewriter.getI64Type(), op.getArg());
```
I'm surprised to see an IndexCast here. Isn't there a different mechanism in MLIR to specify how IndexType is lowered?
Why do we have to convert index values at all? Can we defer this to the point where we lower to LLVM?
My take on this is that `LLVMConversionTarget` might not be the best way to resolve this. As you suggested, the index type is converted to an integer type based on the data layout (or an explicit override). On a 32-bit machine, e.g. RV32, index is typically converted to i32. That still causes the overflow problem, because a case value such as 2^32 needs an i64 switch argument. If the conversion bitwidth is overridden to i64, it also affects other dialect conversions, such as memref's, since `memref.extract_aligned_pointer_as_index` returns an index value. Casting an i64 to an `llvm.ptr` seems awkward on a 32-bit machine.
An alternative is to lower to LLVM in multiple steps with different conversion targets, but I am not sure this approach is user-friendly. Or am I missing something here?
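A small check makes the width question in this thread concrete. A case value survives the cast to the integer type chosen for `index` only if it is representable in that many bits. The sketch below interprets values as signed, matching the sign-extending behavior of `arith.index_cast` when widening. `fitsSigned` and `casesFitIndexWidth` are hypothetical helpers, not part of MLIR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns true if `v` is representable as a signed integer of `bits` bits.
// Requires 1 <= bits <= 64.
bool fitsSigned(int64_t v, unsigned bits) {
  assert(bits >= 1 && bits <= 64);
  if (bits == 64)
    return true;
  const int64_t max = (int64_t{1} << (bits - 1)) - 1;
  const int64_t min = -max - 1;
  return v >= min && v <= max;
}

// Hypothetical helper: true if every case value fits in the integer width
// that `index` is lowered to on the target (e.g. 32 on RV32, 64 on x86-64).
bool casesFitIndexWidth(const std::vector<int64_t> &cases, unsigned indexBits) {
  for (int64_t v : cases)
    if (!fitsSigned(v, indexBits))
      return false;
  return true;
}
```

With a 32-bit index, the case set `{0, 1, 2^32}` fails this check, while with a 64-bit index it passes. This is why lowering index to i32 on RV32 does not by itself avoid the overflow.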
This PR fixes #111589 by making sure `int64_t` is used when converting case values. Using `int32_t` may cause an overflow and result in invalid IR, as shown in the issue. A test case is also added.
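The overflow is ordinary C++ narrowing. Assuming the case values arrive as `int64_t`, storing them in a `SmallVector<int32_t>` discards their upper 32 bits. The sketch below uses `std::vector` in place of `llvm::SmallVector` so it builds without LLVM, and `narrowToI32` and `allDistinct` are hypothetical helpers. It shows one way narrowing can break a switch: two distinct cases collide. This is an illustration, not necessarily the exact failure reported in #111589.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Mirrors the old code path: each int64_t case value is converted to int32_t.
// The conversion wraps modulo 2^32 (guaranteed in C++20, and the behavior of
// mainstream compilers before that).
std::vector<int32_t> narrowToI32(const std::vector<int64_t> &cases) {
  std::vector<int32_t> out;
  out.reserve(cases.size());
  for (int64_t v : cases)
    out.push_back(static_cast<int32_t>(v));
  return out;
}

// Returns true if all case values are pairwise distinct.
template <typename T>
bool allDistinct(const std::vector<T> &values) {
  return std::set<T>(values.begin(), values.end()).size() == values.size();
}
```

For the cases `{0, 2^32}`, the 64-bit values are distinct, but after `narrowToI32` both become `0`, so the switch would have two cases with the same value.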