@@ -48,10 +48,43 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
4848 diag << " Expected single condition use: " << *cmp;
4949 });
5050
51+ std::optional<SmallVector<unsigned >> argReorder;
5152 // All `before` block args must be directly forwarded to ConditionOp.
5253 // They will be converted to `scf.for` `iter_vars` except induction var.
53- if (ValueRange (beforeBody->getArguments ()) != beforeTerm.getArgs ())
54- return rewriter.notifyMatchFailure (loop, " Invalid args order" );
54+ if (ValueRange (beforeBody->getArguments ()) != beforeTerm.getArgs ()) {
55+ auto getArgReordering =
56+ [](Block *beforeBody,
57+ scf::ConditionOp cond) -> std::optional<SmallVector<unsigned >> {
58+ // Skip further checking if their sizes mismatch.
59+ if (beforeBody->getNumArguments () != cond.getArgs ().size ())
60+ return std::nullopt ;
61+ // Bitset on which 'before' argument is forwarded.
62+ BitVector forwarded (beforeBody->getNumArguments (), false );
63+ // The forwarding order of 'before' arguments.
64+ SmallVector<unsigned > order;
65+ for (Value a : cond.getArgs ()) {
66+ BlockArgument arg = dyn_cast<BlockArgument>(a);
67+ // Skip if 'arg' is not a 'before' argument.
68+ if (!arg || arg.getOwner () != beforeBody)
69+ return std::nullopt ;
70+ unsigned idx = arg.getArgNumber ();
71+ // Skip if 'arg' is already forwarded in another place.
72+ if (forwarded[idx])
73+ return std::nullopt ;
74+ // Record the presence of 'arg' and its order.
75+ forwarded[idx] = true ;
76+ order.push_back (idx);
77+ }
78+ // Skip if not all 'before' arguments are forwarded.
79+ if (!forwarded.all ())
80+ return std::nullopt ;
81+ return order;
82+ };
83+ // Check if 'before' arguments are all forwarded but just reordered.
84+ argReorder = getArgReordering (beforeBody, beforeTerm);
85+ if (!argReorder)
86+ return rewriter.notifyMatchFailure (loop, " Invalid args order" );
87+ }
5588
5689 using Pred = arith::CmpIPredicate;
5790 Pred predicate = cmp.getPredicate ();
@@ -100,6 +133,17 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
100133 });
101134
102135 Block *afterBody = loop.getAfterBody ();
136+ if (argReorder) {
137+ // If forwarded arguments are not the same order as 'before' arguments,
138+ // reorder them before converting 'after' body into 'for' body.
139+ for (unsigned order : *argReorder) {
140+ BlockArgument oldArg = afterBody->getArgument (order);
141+ BlockArgument newArg =
142+ afterBody->addArgument (oldArg.getType (), oldArg.getLoc ());
143+ oldArg.replaceAllUsesWith (newArg);
144+ }
145+ afterBody->eraseArguments (0 , argReorder->size ());
146+ }
103147 scf::YieldOp afterTerm = loop.getYieldOp ();
104148 unsigned argNumber = inductionVar.getArgNumber ();
105149 Value afterTermIndArg = afterTerm.getResults ()[argNumber];
@@ -130,7 +174,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
130174 assert (lb.getType () == ub.getType ());
131175 assert (lb.getType () == step.getType ());
132176
133- llvm:: SmallVector<Value> newArgs;
177+ SmallVector<Value> newArgs;
134178
135179 // Populate inits for new `scf.for`, skip induction var.
136180 newArgs.reserve (loop.getInits ().size ());
@@ -205,6 +249,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
205249 newArgs.clear ();
206250 llvm::append_range (newArgs, newLoop.getResults ());
207251 newArgs.insert (newArgs.begin () + argNumber, res);
252+ if (argReorder) {
253+ // If 'yield' arguments (or forwarded arguments) are not the same order as
254+ // 'before' arguments (or 'for' results), reorder them.
255+ SmallVector<Value> results;
256+ for (unsigned order : *argReorder)
257+ results.push_back (newArgs[order]);
258+ newArgs = results;
259+ }
208260 rewriter.replaceOp (loop, newArgs);
209261 return newLoop;
210262}
0 commit comments