@Kuree Kuree commented Oct 8, 2024

This PR fixes #111589 by using int64_t when converting case values. Storing them as int32_t can overflow and produce invalid IR, as shown in the issue. A test case is also added.
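To make the overflow concrete, here is a minimal standalone sketch of what narrowing a 64-bit case value to 32 bits does. It does not use MLIR, and `narrowToI32` and `casesStayDistinct` are hypothetical helper names introduced only for illustration. It assumes a two's-complement target, where the narrowing keeps the low 32 bits.

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical helper (not an MLIR API): keep only the low 32 bits of a
// 64-bit case value, which is what storing it in an int32_t amounts to on
// common two's-complement targets.
inline int32_t narrowToI32(int64_t value) {
  // int64_t -> uint32_t is well-defined (reduction modulo 2^32).
  return static_cast<int32_t>(static_cast<uint32_t>(value));
}

// Hypothetical helper: returns true if every case value is still distinct
// after being narrowed to 32 bits.
inline bool casesStayDistinct(const std::vector<int64_t> &cases) {
  std::set<int32_t> seen;
  for (int64_t c : cases) {
    // insert().second is false when the narrowed value was already present.
    if (!seen.insert(narrowToI32(c)).second)
      return false;
  }
  return true;
}
```

With these helpers, `narrowToI32(4294967296LL)` yields 0, so a switch with cases 0 and 2^32 would end up with two identical case values after narrowing. This is one way the narrowing can go wrong; the exact failure reported in the issue may differ.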

@llvmbot llvmbot added the mlir label Oct 8, 2024
llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-mlir

Author: Keyi Zhang (Kuree)

Full diff: https://github.com/llvm/llvm-project/pull/111590.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (+6-3)
  • (modified) mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir (+19-2)
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 45f3bcfa393be8..5b7b6713397048 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -669,7 +669,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
 
   // Convert the case regions.
   SmallVector<Block *> caseSuccessors;
-  SmallVector<int32_t> caseValues;
+  SmallVector<int64_t> caseValues;
   caseSuccessors.reserve(op.getCases().size());
   caseValues.reserve(op.getCases().size());
   for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
@@ -691,11 +691,14 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
 
   // Cast switch index to integer case value.
   Value caseValue = rewriter.create<arith::IndexCastOp>(
-      op.getLoc(), rewriter.getI32Type(), op.getArg());
+      op.getLoc(), rewriter.getI64Type(), op.getArg());
 
+  ShapedType caseValueType = VectorType::get(
+      static_cast<int64_t>(caseValues.size()), rewriter.getI64Type());
   rewriter.create<cf::SwitchOp>(
       op.getLoc(), caseValue, *defaultBlock, ValueRange(),
-      rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+      DenseIntElementsAttr::get(caseValueType, caseValues), caseSuccessors,
+      caseOperands);
   rewriter.replaceOp(op, continueBlock->getArguments());
   return success();
 }
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 9ea0093eff7868..ba841313320194 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -622,8 +622,8 @@ func.func @func_execute_region_elim_multi_yield() {
 
 // CHECK-LABEL: @index_switch
 func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
-  // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i32
-  // CHECK: cf.switch %[[CASE]] : i32
+  // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64
+  // CHECK: cf.switch %[[CASE]] : i64
   // CHECK-NEXT: default: ^[[DEFAULT:.+]],
   // CHECK-NEXT: 0: ^[[bb1:.+]],
   // CHECK-NEXT: 1: ^[[bb2:.+]]
@@ -648,6 +648,23 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: @index_switch_large_case
+func.func @index_switch_large_case(%i : index) {
+  // CHECK: cf.switch
+  // CHECK: 4294967296: ^[[bb1:.+]]
+  scf.index_switch %i
+  case 4294967296 { // 2^32
+    // CHECK: ^[[bb1]]:
+    // CHECK-NEXT: "test.op"
+    "test.op"() : () -> ()
+    scf.yield
+  }
+  default {
+    scf.yield
+  }
+  return
+}
+
 // Note: scf.forall is lowered to scf.parallel, which is currently lowered to
 // scf.for and then to unstructured control flow. scf.parallel could lower more
 // efficiently to multi-threaded IR, at which point scf.forall would

   // Cast switch index to integer case value.
   Value caseValue = rewriter.create<arith::IndexCastOp>(
-      op.getLoc(), rewriter.getI32Type(), op.getArg());
+      op.getLoc(), rewriter.getI64Type(), op.getArg());
Member

I'm surprised to see an IndexCast here. Isn't there a different mechanism in MLIR to specify how IndexType is lowered?

Member

Why do we have to convert index values at all? Can we defer this to the point where we lower to LLVM?

Contributor Author

My take is that LLVMConversionTarget might not be the best way to resolve this. The index type is converted to an integer type based on the data layout (or overridden by hand), as you suggested. On a 32-bit machine, e.g. RV32, the index type is typically converted to i32. This still causes the overflow problem, since we need i64 for the switch argument. If the conversion bitwidth is overridden to i64, it will affect other dialect conversions, such as memref, since memref.extract_aligned_pointer_as_index returns an index type. Having an i64 cast to an llvm.ptr seems awkward on a 32-bit machine.

An alternative is to lower to LLVM in multiple steps with different conversion targets, but I am not sure this approach is user-friendly. Or am I missing something here?
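
As a rough illustration of the bitwidth concern above, the following standalone sketch checks whether a case value fits in a signed integer of a given width. `fitsInSignedBits` is a hypothetical helper, not an MLIR or LLVM API, and the widths 32 and 64 stand in for the index bitwidths discussed in this thread.

```cpp
#include <cstdint>

// Hypothetical helper (not an MLIR API): returns true if `value` is
// representable as a signed two's-complement integer with `bits` bits.
// Assumes 1 <= bits <= 64.
inline bool fitsInSignedBits(int64_t value, unsigned bits) {
  if (bits >= 64)
    return true; // Every int64_t fits in 64 bits.
  const int64_t maxValue = (int64_t{1} << (bits - 1)) - 1;
  const int64_t minValue = -maxValue - 1;
  return value >= minValue && value <= maxValue;
}
```

Under this check, the case value 4294967296 (2^32) from the new test fits in 64 bits but not in 32, which is why lowering the index type to i32 alone would not avoid the overflow.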

Successfully merging this pull request may close these issues.

[MLIR] Invalid scf -> cf lowering when the scf.index_switch case values are greater than i32

3 participants