diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp index 7b4024b6861a7..ebe718ae4fb61 100644 --- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp @@ -48,10 +48,46 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, diag << "Expected single condition use: " << *cmp; }); + // If all 'before' arguments are forwarded but the order is different from + // 'after' arguments, here is the mapping from the 'after' argument index to + // the 'before' argument index. + std::optional> argReorder; // All `before` block args must be directly forwarded to ConditionOp. // They will be converted to `scf.for` `iter_vars` except induction var. - if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) - return rewriter.notifyMatchFailure(loop, "Invalid args order"); + if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) { + auto getArgReordering = + [](Block *beforeBody, + scf::ConditionOp cond) -> std::optional> { + // Skip further checking if their sizes mismatch. + if (beforeBody->getNumArguments() != cond.getArgs().size()) + return std::nullopt; + // Bitset on which 'before' argument is forwarded. + llvm::SmallBitVector forwarded(beforeBody->getNumArguments(), false); + // The forwarding order of 'before' arguments. + SmallVector order; + for (Value a : cond.getArgs()) { + BlockArgument arg = dyn_cast(a); + // Skip if 'arg' is not a 'before' argument. + if (!arg || arg.getOwner() != beforeBody) + return std::nullopt; + unsigned idx = arg.getArgNumber(); + // Skip if 'arg' is already forwarded in another place. + if (forwarded[idx]) + return std::nullopt; + // Record the presence of 'arg' and its order. + forwarded[idx] = true; + order.push_back(idx); + } + // Skip if not all 'before' arguments are forwarded. + if (!forwarded.all()) + return std::nullopt; + return order; + }; + // Check if 'before' arguments are all forwarded but just reordered. + argReorder = getArgReordering(beforeBody, beforeTerm); + if (!argReorder) + return rewriter.notifyMatchFailure(loop, "Invalid args order"); + } using Pred = arith::CmpIPredicate; Pred predicate = cmp.getPredicate(); @@ -104,7 +140,14 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, unsigned argNumber = inductionVar.getArgNumber(); Value afterTermIndArg = afterTerm.getResults()[argNumber]; - Value inductionVarAfter = afterBody->getArgument(argNumber); + auto findAfterArgNo = [](ArrayRef indices, unsigned beforeArgNo) { + return std::distance(indices.begin(), + llvm::find_if(indices, [beforeArgNo](unsigned n) { + return n == beforeArgNo; + })); + }; + Value inductionVarAfter = afterBody->getArgument( + argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber); // Find suitable `addi` op inside `after` block, one of the args must be an // Induction var passed from `before` block and second arg must be defined @@ -130,7 +173,7 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, assert(lb.getType() == ub.getType()); assert(lb.getType() == step.getType()); - llvm::SmallVector newArgs; + SmallVector newArgs; // Populate inits for new `scf.for`, skip induction var. newArgs.reserve(loop.getInits().size()); @@ -164,6 +207,14 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, newArgs.emplace_back(newBodyArgs[i]); } } + if (argReorder) { + // Reorder arguments following the 'after' argument order from the original + // 'while' loop. + SmallVector args; + for (unsigned order : *argReorder) + args.push_back(newArgs[order]); + newArgs = args; + } rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(), newArgs); @@ -205,6 +256,14 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, newArgs.clear(); llvm::append_range(newArgs, newLoop.getResults()); newArgs.insert(newArgs.begin() + argNumber, res); + if (argReorder) { + // Reorder arguments following the 'after' argument order from the original + // 'while' loop. + SmallVector results; + for (unsigned order : *argReorder) + results.push_back(newArgs[order]); + newArgs = results; + } rewriter.replaceOp(loop, newArgs); return newLoop; } diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir index 25ea6142a332d..cbe2ce5076ad2 100644 --- a/mlir/test/Dialect/SCF/uplift-while.mlir +++ b/mlir/test/Dialect/SCF/uplift-while.mlir @@ -155,3 +155,33 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 { // CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64 // CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64 // CHECK: return %[[R7]] : i64 + +// ----- + +// A case where all 'before' arguments are forwarded but reordered. +func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) { + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2.0 : f32 + %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (index, i32, f32) { + %1 = arith.cmpi slt, %arg3, %arg1 : index + scf.condition(%1) %arg3, %arg4, %arg5 : index, i32, f32 + } do { + ^bb0(%arg3: index, %arg4: i32, %arg5: f32): + %1 = "test.test1"(%arg4) : (i32) -> i32 + %added = arith.addi %arg3, %arg2 : index + %2 = "test.test2"(%arg5) : (f32) -> f32 + scf.yield %1, %added, %2 : i32, index, f32 + } + return %0#1, %0#2 : i32, f32 +} + +// CHECK-LABEL: func @uplift_while +// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32) +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] +// CHECK-SAME: iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) { +// CHECK: %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32 +// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32 +// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32 +// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32