-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][scf] Allow different forwarding ordering in uplift #133117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: None (darkbuck) Changes
Full diff: https://github.com/llvm/llvm-project/pull/133117.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 7b4024b6861a7..9c4fe702de119 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -48,10 +48,43 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
diag << "Expected single condition use: " << *cmp;
});
+ std::optional<SmallVector<unsigned>> argReorder;
// All `before` block args must be directly forwarded to ConditionOp.
// They will be converted to `scf.for` `iter_vars` except induction var.
- if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
- return rewriter.notifyMatchFailure(loop, "Invalid args order");
+ if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
+ auto getArgReordering =
+ [](Block *beforeBody,
+ scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> {
+ // Skip further checking if their sizes mismatch.
+ if (beforeBody->getNumArguments() != cond.getArgs().size())
+ return std::nullopt;
+ // Bitset on which 'before' argument is forwarded.
+ BitVector forwarded(beforeBody->getNumArguments(), false);
+ // The forwarding order of 'before' arguments.
+ SmallVector<unsigned> order;
+ for (Value a : cond.getArgs()) {
+ BlockArgument arg = dyn_cast<BlockArgument>(a);
+ // Skip if 'arg' is not a 'before' argument.
+ if (!arg || arg.getOwner() != beforeBody)
+ return std::nullopt;
+ unsigned idx = arg.getArgNumber();
+ // Skip if 'arg' is already forwarded in another place.
+ if (forwarded[idx])
+ return std::nullopt;
+ // Record the presence of 'arg' and its order.
+ forwarded[idx] = true;
+ order.push_back(idx);
+ }
+ // Skip if not all 'before' arguments are forwarded.
+ if (!forwarded.all())
+ return std::nullopt;
+ return order;
+ };
+ // Check if 'before' arguments are all forwarded but just reordered.
+ argReorder = getArgReordering(beforeBody, beforeTerm);
+ if (!argReorder)
+ return rewriter.notifyMatchFailure(loop, "Invalid args order");
+ }
using Pred = arith::CmpIPredicate;
Pred predicate = cmp.getPredicate();
@@ -100,6 +133,17 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
});
Block *afterBody = loop.getAfterBody();
+ if (argReorder) {
+ // If forwarded arguments are not the same order as 'before' arguments,
+ // reorder them before converting 'after' body into 'for' body.
+ for (unsigned order : *argReorder) {
+ BlockArgument oldArg = afterBody->getArgument(order);
+ BlockArgument newArg =
+ afterBody->addArgument(oldArg.getType(), oldArg.getLoc());
+ oldArg.replaceAllUsesWith(newArg);
+ }
+ afterBody->eraseArguments(0, argReorder->size());
+ }
scf::YieldOp afterTerm = loop.getYieldOp();
unsigned argNumber = inductionVar.getArgNumber();
Value afterTermIndArg = afterTerm.getResults()[argNumber];
@@ -130,7 +174,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
assert(lb.getType() == ub.getType());
assert(lb.getType() == step.getType());
- llvm::SmallVector<Value> newArgs;
+ SmallVector<Value> newArgs;
// Populate inits for new `scf.for`, skip induction var.
newArgs.reserve(loop.getInits().size());
@@ -205,6 +249,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
newArgs.clear();
llvm::append_range(newArgs, newLoop.getResults());
newArgs.insert(newArgs.begin() + argNumber, res);
+ if (argReorder) {
+ // If 'yield' arguments (or forwarded arguments) are not the same order as
+ // 'before' arguments (or 'for' results), reorder them.
+ SmallVector<Value> results;
+ for (unsigned order : *argReorder)
+ results.push_back(newArgs[order]);
+ newArgs = results;
+ }
rewriter.replaceOp(loop, newArgs);
return newLoop;
}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index 25ea6142a332d..cbe2ce5076ad2 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -155,3 +155,33 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
// CHECK: return %[[R7]] : i64
+
+// -----
+
+// A case where all 'before' arguments are forwarded but reordered.
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2.0 : f32
+ %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (index, i32, f32) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg3, %arg4, %arg5 : index, i32, f32
+ } do {
+ ^bb0(%arg3: index, %arg4: i32, %arg5: f32):
+ %1 = "test.test1"(%arg4) : (i32) -> i32
+ %added = arith.addi %arg3, %arg2 : index
+ %2 = "test.test2"(%arg5) : (f32) -> f32
+ scf.yield %1, %added, %2 : i32, index, f32
+ }
+ return %0#1, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+// CHECK-SAME: iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+// CHECK: %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
+// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
|
Hardcode84
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I have a couple comments.
- Allow 'before' arguments are forwarded in different order to 'after' body when uplifting `scf.while` to `scf.for`.
25d6e13 to
d23a38c
Compare
Hardcode84
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
|
Thanks! |
|
@darkbuck did you mean to merge it? :) |
scf.whiletoscf.for.