Skip to content

Conversation

@darkbuck
Copy link
Contributor

  • Allow 'before' arguments are forwarded in different order to 'after' body when uplifting scf.while to scf.for.

@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: None (darkbuck)

Changes
  • Allow 'before' arguments are forwarded in different order to 'after' body when uplifting scf.while to scf.for.

Full diff: https://github.com/llvm/llvm-project/pull/133117.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (+55-3)
  • (modified) mlir/test/Dialect/SCF/uplift-while.mlir (+30)
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 7b4024b6861a7..9c4fe702de119 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -48,10 +48,43 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
       diag << "Expected single condition use: " << *cmp;
     });
 
+  std::optional<SmallVector<unsigned>> argReorder;
   // All `before` block args must be directly forwarded to ConditionOp.
   // They will be converted to `scf.for` `iter_vars` except induction var.
-  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
-    return rewriter.notifyMatchFailure(loop, "Invalid args order");
+  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
+    auto getArgReordering =
+        [](Block *beforeBody,
+           scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> {
+      // Skip further checking if their sizes mismatch.
+      if (beforeBody->getNumArguments() != cond.getArgs().size())
+        return std::nullopt;
+      // Bitset on which 'before' argument is forwarded.
+      BitVector forwarded(beforeBody->getNumArguments(), false);
+      // The forwarding order of 'before' arguments.
+      SmallVector<unsigned> order;
+      for (Value a : cond.getArgs()) {
+        BlockArgument arg = dyn_cast<BlockArgument>(a);
+        // Skip if 'arg' is not a 'before' argument.
+        if (!arg || arg.getOwner() != beforeBody)
+          return std::nullopt;
+        unsigned idx = arg.getArgNumber();
+        // Skip if 'arg' is already forwarded in another place.
+        if (forwarded[idx])
+          return std::nullopt;
+        // Record the presence of 'arg' and its order.
+        forwarded[idx] = true;
+        order.push_back(idx);
+      }
+      // Skip if not all 'before' arguments are forwarded.
+      if (!forwarded.all())
+        return std::nullopt;
+      return order;
+    };
+    // Check if 'before' arguments are all forwarded but just reordered.
+    argReorder = getArgReordering(beforeBody, beforeTerm);
+    if (!argReorder)
+      return rewriter.notifyMatchFailure(loop, "Invalid args order");
+  }
 
   using Pred = arith::CmpIPredicate;
   Pred predicate = cmp.getPredicate();
@@ -100,6 +133,17 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
     });
 
   Block *afterBody = loop.getAfterBody();
+  if (argReorder) {
+    // If forwarded arguments are not the same order as 'before' arguments,
+    // reorder them before converting 'after' body into 'for' body.
+    for (unsigned order : *argReorder) {
+      BlockArgument oldArg = afterBody->getArgument(order);
+      BlockArgument newArg =
+          afterBody->addArgument(oldArg.getType(), oldArg.getLoc());
+      oldArg.replaceAllUsesWith(newArg);
+    }
+    afterBody->eraseArguments(0, argReorder->size());
+  }
   scf::YieldOp afterTerm = loop.getYieldOp();
   unsigned argNumber = inductionVar.getArgNumber();
   Value afterTermIndArg = afterTerm.getResults()[argNumber];
@@ -130,7 +174,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
   assert(lb.getType() == ub.getType());
   assert(lb.getType() == step.getType());
 
-  llvm::SmallVector<Value> newArgs;
+  SmallVector<Value> newArgs;
 
   // Populate inits for new `scf.for`, skip induction var.
   newArgs.reserve(loop.getInits().size());
@@ -205,6 +249,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
   newArgs.clear();
   llvm::append_range(newArgs, newLoop.getResults());
   newArgs.insert(newArgs.begin() + argNumber, res);
+  if (argReorder) {
+    // If 'yield' arguments (or forwarded arguments) are not the same order as
+    // 'before' arguments (or 'for' results), reorder them.
+    SmallVector<Value> results;
+    for (unsigned order : *argReorder)
+      results.push_back(newArgs[order]);
+    newArgs = results;
+  }
   rewriter.replaceOp(loop, newArgs);
   return newLoop;
 }
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index 25ea6142a332d..cbe2ce5076ad2 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -155,3 +155,33 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
 //       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
 //       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
 //       CHECK:     return %[[R7]] : i64
+
+// -----
+
+// A case where all 'before' arguments are forwarded but reordered.
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2.0 : f32
+  %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (index, i32, f32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3, %arg4, %arg5 : index, i32, f32
+  } do {
+  ^bb0(%arg3: index, %arg4: i32, %arg5: f32):
+    %1 = "test.test1"(%arg4) : (i32) -> i32
+    %added = arith.addi %arg3, %arg2 : index
+    %2 = "test.test2"(%arg5) : (f32) -> f32
+    scf.yield %1, %added, %2 : i32, index, f32
+  }
+  return %0#1, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:     %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:     %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+//  CHECK-SAME:     iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+//       CHECK:     %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+//       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+//       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
+//       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32

Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I have a couple comments.

- Allow 'before' arguments are forwarded in different order to 'after'
  body when uplifting `scf.while` to `scf.for`.
@darkbuck darkbuck force-pushed the hliao/main/refine-uplift branch from 25d6e13 to d23a38c Compare March 27, 2025 20:09
Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@darkbuck
Copy link
Contributor Author

Thanks!

@darkbuck darkbuck closed this Mar 27, 2025
@Hardcode84
Copy link
Contributor

@darkbuck did you mean to merge it? :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants