Skip to content

Commit 25d6e13

Browse files
committed
[mlir][scf] Allow different forwarding ordering in uplift
- Allow 'before' arguments are forwarded in different order to 'after' body when uplifting `scf.while` to `scf.for`.
1 parent 52f941a commit 25d6e13

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,43 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
4848
diag << "Expected single condition use: " << *cmp;
4949
});
5050

51+
std::optional<SmallVector<unsigned>> argReorder;
5152
// All `before` block args must be directly forwarded to ConditionOp.
5253
// They will be converted to `scf.for` `iter_vars` except induction var.
53-
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
54-
return rewriter.notifyMatchFailure(loop, "Invalid args order");
54+
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
55+
auto getArgReordering =
56+
[](Block *beforeBody,
57+
scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> {
58+
// Skip further checking if their sizes mismatch.
59+
if (beforeBody->getNumArguments() != cond.getArgs().size())
60+
return std::nullopt;
61+
// Bitset on which 'before' argument is forwarded.
62+
BitVector forwarded(beforeBody->getNumArguments(), false);
63+
// The forwarding order of 'before' arguments.
64+
SmallVector<unsigned> order;
65+
for (Value a : cond.getArgs()) {
66+
BlockArgument arg = dyn_cast<BlockArgument>(a);
67+
// Skip if 'arg' is not a 'before' argument.
68+
if (!arg || arg.getOwner() != beforeBody)
69+
return std::nullopt;
70+
unsigned idx = arg.getArgNumber();
71+
// Skip if 'arg' is already forwarded in another place.
72+
if (forwarded[idx])
73+
return std::nullopt;
74+
// Record the presence of 'arg' and its order.
75+
forwarded[idx] = true;
76+
order.push_back(idx);
77+
}
78+
// Skip if not all 'before' arguments are forwarded.
79+
if (!forwarded.all())
80+
return std::nullopt;
81+
return order;
82+
};
83+
// Check if 'before' arguments are all forwarded but just reordered.
84+
argReorder = getArgReordering(beforeBody, beforeTerm);
85+
if (!argReorder)
86+
return rewriter.notifyMatchFailure(loop, "Invalid args order");
87+
}
5588

5689
using Pred = arith::CmpIPredicate;
5790
Pred predicate = cmp.getPredicate();
@@ -100,6 +133,17 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
100133
});
101134

102135
Block *afterBody = loop.getAfterBody();
136+
if (argReorder) {
137+
// If forwarded arguments are not the same order as 'before' arguments,
138+
// reorder them before converting 'after' body into 'for' body.
139+
for (unsigned order : *argReorder) {
140+
BlockArgument oldArg = afterBody->getArgument(order);
141+
BlockArgument newArg =
142+
afterBody->addArgument(oldArg.getType(), oldArg.getLoc());
143+
oldArg.replaceAllUsesWith(newArg);
144+
}
145+
afterBody->eraseArguments(0, argReorder->size());
146+
}
103147
scf::YieldOp afterTerm = loop.getYieldOp();
104148
unsigned argNumber = inductionVar.getArgNumber();
105149
Value afterTermIndArg = afterTerm.getResults()[argNumber];
@@ -130,7 +174,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
130174
assert(lb.getType() == ub.getType());
131175
assert(lb.getType() == step.getType());
132176

133-
llvm::SmallVector<Value> newArgs;
177+
SmallVector<Value> newArgs;
134178

135179
// Populate inits for new `scf.for`, skip induction var.
136180
newArgs.reserve(loop.getInits().size());
@@ -205,6 +249,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
205249
newArgs.clear();
206250
llvm::append_range(newArgs, newLoop.getResults());
207251
newArgs.insert(newArgs.begin() + argNumber, res);
252+
if (argReorder) {
253+
// If 'yield' arguments (or forwarded arguments) are not the same order as
254+
// 'before' arguments (or 'for' results), reorder them.
255+
SmallVector<Value> results;
256+
for (unsigned order : *argReorder)
257+
results.push_back(newArgs[order]);
258+
newArgs = results;
259+
}
208260
rewriter.replaceOp(loop, newArgs);
209261
return newLoop;
210262
}

mlir/test/Dialect/SCF/uplift-while.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,33 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
155155
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
156156
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
157157
// CHECK: return %[[R7]] : i64
158+
159+
// -----
160+
161+
// A case where all 'before' arguments are forwarded but reordered.
162+
func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
163+
%c1 = arith.constant 1 : i32
164+
%c2 = arith.constant 2.0 : f32
165+
%0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (index, i32, f32) {
166+
%1 = arith.cmpi slt, %arg3, %arg1 : index
167+
scf.condition(%1) %arg3, %arg4, %arg5 : index, i32, f32
168+
} do {
169+
^bb0(%arg3: index, %arg4: i32, %arg5: f32):
170+
%1 = "test.test1"(%arg4) : (i32) -> i32
171+
%added = arith.addi %arg3, %arg2 : index
172+
%2 = "test.test2"(%arg5) : (f32) -> f32
173+
scf.yield %1, %added, %2 : i32, index, f32
174+
}
175+
return %0#1, %0#2 : i32, f32
176+
}
177+
178+
// CHECK-LABEL: func @uplift_while
179+
// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
180+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
181+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
182+
// CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
183+
// CHECK-SAME: iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
184+
// CHECK: %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
185+
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
186+
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
187+
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32

0 commit comments

Comments
 (0)