@@ -48,10 +48,46 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
4848 diag << " Expected single condition use: " << *cmp;
4949 });
5050
51+ // If all 'before' arguments are forwarded but the order is different from
52+ // 'after' arguments, here is the mapping from the 'after' argument index to
53+ // the 'before' argument index.
54+ std::optional<SmallVector<unsigned >> argReorder;
5155 // All `before` block args must be directly forwarded to ConditionOp.
5256 // They will be converted to `scf.for` `iter_vars` except induction var.
53- if (ValueRange (beforeBody->getArguments ()) != beforeTerm.getArgs ())
54- return rewriter.notifyMatchFailure (loop, " Invalid args order" );
57+ if (ValueRange (beforeBody->getArguments ()) != beforeTerm.getArgs ()) {
58+ auto getArgReordering =
59+ [](Block *beforeBody,
60+ scf::ConditionOp cond) -> std::optional<SmallVector<unsigned >> {
61+ // Skip further checking if their sizes mismatch.
62+ if (beforeBody->getNumArguments () != cond.getArgs ().size ())
63+ return std::nullopt ;
64+ // Bitset on which 'before' argument is forwarded.
65+ llvm::SmallBitVector forwarded (beforeBody->getNumArguments (), false );
66+ // The forwarding order of 'before' arguments.
67+ SmallVector<unsigned > order;
68+ for (Value a : cond.getArgs ()) {
69+ BlockArgument arg = dyn_cast<BlockArgument>(a);
70+ // Skip if 'arg' is not a 'before' argument.
71+ if (!arg || arg.getOwner () != beforeBody)
72+ return std::nullopt ;
73+ unsigned idx = arg.getArgNumber ();
74+ // Skip if 'arg' is already forwarded in another place.
75+ if (forwarded[idx])
76+ return std::nullopt ;
77+ // Record the presence of 'arg' and its order.
78+ forwarded[idx] = true ;
79+ order.push_back (idx);
80+ }
81+ // Skip if not all 'before' arguments are forwarded.
82+ if (!forwarded.all ())
83+ return std::nullopt ;
84+ return order;
85+ };
86+ // Check if 'before' arguments are all forwarded but just reordered.
87+ argReorder = getArgReordering (beforeBody, beforeTerm);
88+ if (!argReorder)
89+ return rewriter.notifyMatchFailure (loop, " Invalid args order" );
90+ }
5591
5692 using Pred = arith::CmpIPredicate;
5793 Pred predicate = cmp.getPredicate ();
@@ -104,7 +140,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
104140 unsigned argNumber = inductionVar.getArgNumber ();
105141 Value afterTermIndArg = afterTerm.getResults ()[argNumber];
106142
107- Value inductionVarAfter = afterBody->getArgument (argNumber);
143+ auto findAfterArgNo = [](ArrayRef<unsigned > indices, unsigned beforeArgNo) {
144+ return std::distance (indices.begin (),
145+ llvm::find_if (indices, [beforeArgNo](unsigned n) {
146+ return n == beforeArgNo;
147+ }));
148+ };
149+ Value inductionVarAfter = afterBody->getArgument (
150+ argReorder ? findAfterArgNo (*argReorder, argNumber) : argNumber);
108151
109152 // Find suitable `addi` op inside `after` block, one of the args must be an
110153 // Induction var passed from `before` block and second arg must be defined
@@ -130,7 +173,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
130173 assert (lb.getType () == ub.getType ());
131174 assert (lb.getType () == step.getType ());
132175
133- llvm:: SmallVector<Value> newArgs;
176+ SmallVector<Value> newArgs;
134177
135178 // Populate inits for new `scf.for`, skip induction var.
136179 newArgs.reserve (loop.getInits ().size ());
@@ -164,6 +207,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
164207 newArgs.emplace_back (newBodyArgs[i]);
165208 }
166209 }
210+ if (argReorder) {
211+ // Reorder arguments following the 'after' argument order from the original
212+ // 'while' loop.
213+ SmallVector<Value> args;
214+ for (unsigned order : *argReorder)
215+ args.push_back (newArgs[order]);
216+ newArgs = args;
217+ }
167218
168219 rewriter.inlineBlockBefore (loop.getAfterBody (), newBody, newBody->end (),
169220 newArgs);
@@ -205,6 +256,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
205256 newArgs.clear ();
206257 llvm::append_range (newArgs, newLoop.getResults ());
207258 newArgs.insert (newArgs.begin () + argNumber, res);
259+ if (argReorder) {
260+ // Reorder arguments following the 'after' argument order from the original
261+ // 'while' loop.
262+ SmallVector<Value> results;
263+ for (unsigned order : *argReorder)
264+ results.push_back (newArgs[order]);
265+ newArgs = results;
266+ }
208267 rewriter.replaceOp (loop, newArgs);
209268 return newLoop;
210269}
0 commit comments