Skip to content

Conversation

@darkbuck
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: None (darkbuck)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/139911.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (+11-18)
  • (modified) mlir/test/Dialect/SCF/uplift-while.mlir (+64)
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ebe718ae4fb61..0fabaf6e63ee4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -91,9 +91,10 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
 
   using Pred = arith::CmpIPredicate;
   Pred predicate = cmp.getPredicate();
-  if (predicate != Pred::slt && predicate != Pred::sgt)
+  if (predicate != Pred::slt && predicate != Pred::sgt &&
+      predicate != Pred::ult && predicate != Pred::ugt)
     return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
-      diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+      diag << "Expected 'slt'/'ult' or 'sgt'/'ugt' predicate: " << *cmp;
     });
 
   BlockArgument inductionVar;
@@ -103,24 +104,16 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
   // Check if cmp has a suitable form. One of the arguments must be a `before`
   // block arg, other must be defined outside `scf.while` and will be treated
   // as upper bound.
-  for (bool reverse : {false, true}) {
-    auto expectedPred = reverse ? Pred::sgt : Pred::slt;
-    if (cmp.getPredicate() != expectedPred)
-      continue;
-
-    auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
-    auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
-
-    auto blockArg = dyn_cast<BlockArgument>(arg1);
-    if (!blockArg || blockArg.getOwner() != beforeBody)
-      continue;
-
-    if (!dom.properlyDominates(arg2, loop))
-      continue;
-
+  auto arg1 = cmp.getLhs();
+  auto arg2 = cmp.getRhs();
+  if (predicate == Pred::sgt || predicate == Pred::ugt)
+    std::swap(arg1, arg2);
+
+  auto blockArg = dyn_cast<BlockArgument>(arg1);
+  if (blockArg && blockArg.getOwner() == beforeBody &&
+      dom.properlyDominates(arg2, loop)) {
     inductionVar = blockArg;
     ub = arg2;
-    break;
   }
 
   if (!inductionVar)
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index cbe2ce5076ad2..f11f5ab28d707 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -185,3 +185,67 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
 //       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
 //       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
 //       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi ult, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi ugt, %arg1, %arg3 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index

   Pred predicate = cmp.getPredicate();
-  if (predicate != Pred::slt && predicate != Pred::sgt)
+  if (predicate != Pred::slt && predicate != Pred::sgt &&
+      predicate != Pred::ult && predicate != Pred::ugt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is ult/ugt correct? scf::for will do a signed comparison so the new IR is not equivalent.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. Maybe a better approach would be to try converting unsigned comparisons to signed ones first (e.g., using integer range analysis).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is ult/ugt correct? scf::for will do a signed comparison so the new IR is not equivalent.

Sorry, I may have misread that. But the documentation for scf.for reads:

 151     The `scf.for` operation represents a loop taking 3 SSA value as operands
 152     that represent the lower bound, upper bound and step respectively. The
 153     operation defines an SSA value for its induction variable. It has one
 154     region capturing the loop body. The induction variable is represented as an
 155     argument of this region. This SSA value is a signless integer or index.

Doesn't that mean the induction variable is a signless integer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signless means that the sign interpretation is carried by the operation rather than by the type. The semantics of the op are clarified in the documentation:

The lower and upper bounds specify a half-open range: the iteration is executed iff the signed comparison of induction variable value is less than the upper bound and bigger or equal to the lower bound.

@darkbuck darkbuck closed this May 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants