Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 63 additions & 4 deletions mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,46 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
diag << "Expected single condition use: " << *cmp;
});

// If all 'before' arguments are forwarded but the order is different from
// 'after' arguments, here is the mapping from the 'after' argument index to
// the 'before' argument index.
std::optional<SmallVector<unsigned>> argReorder;
// All `before` block args must be directly forwarded to ConditionOp.
// They will be converted to `scf.for` `iter_args` except induction var.
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
return rewriter.notifyMatchFailure(loop, "Invalid args order");
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
// Computes the permutation that `cond` applies to the arguments of
// `beforeBody`: element `i` of the result is the index of the 'before'
// argument forwarded at position `i` of the condition op. Returns
// `std::nullopt` unless every 'before' argument is forwarded exactly once.
auto getArgReordering =
    [](Block *beforeBody,
       scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> {
  const unsigned numBeforeArgs = beforeBody->getNumArguments();
  auto condArgs = cond.getArgs();
  // A permutation is impossible when the counts differ.
  if (condArgs.size() != numBeforeArgs)
    return std::nullopt;

  // One bit per 'before' argument, set once that argument has been seen.
  llvm::SmallBitVector seen(numBeforeArgs, false);
  // 'before' argument indices in the order they are forwarded.
  SmallVector<unsigned> permutation;
  for (Value forwardedValue : condArgs) {
    auto blockArg = dyn_cast<BlockArgument>(forwardedValue);
    // Give up if the forwarded value is not an argument of `beforeBody`.
    const bool isBeforeArg = blockArg && blockArg.getOwner() == beforeBody;
    if (!isBeforeArg)
      return std::nullopt;

    const unsigned beforeIdx = blockArg.getArgNumber();
    // Give up if the same argument is forwarded more than once.
    if (seen.test(beforeIdx))
      return std::nullopt;

    seen.set(beforeIdx);
    permutation.push_back(beforeIdx);
  }

  // Succeed only when every 'before' argument has been forwarded.
  if (seen.all())
    return permutation;
  return std::nullopt;
};
// Check if 'before' arguments are all forwarded but just reordered.
argReorder = getArgReordering(beforeBody, beforeTerm);
if (!argReorder)
return rewriter.notifyMatchFailure(loop, "Invalid args order");
}

using Pred = arith::CmpIPredicate;
Pred predicate = cmp.getPredicate();
Expand Down Expand Up @@ -104,7 +140,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
unsigned argNumber = inductionVar.getArgNumber();
Value afterTermIndArg = afterTerm.getResults()[argNumber];

Value inductionVarAfter = afterBody->getArgument(argNumber);
// Returns the position within `indices` that holds `beforeArgNo`, i.e. the
// 'after' argument number that receives the given 'before' argument.
// `indices` is expected to contain `beforeArgNo`; otherwise the returned
// position would be one past the end and must not be used as an argument
// index.
auto findAfterArgNo = [](ArrayRef<unsigned> indices, unsigned beforeArgNo) {
  auto it = llvm::find(indices, beforeArgNo);
  assert(it != indices.end() && "'before' argument is not forwarded");
  return std::distance(indices.begin(), it);
};
Value inductionVarAfter = afterBody->getArgument(
argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber);

// Find suitable `addi` op inside `after` block, one of the args must be an
// Induction var passed from `before` block and second arg must be defined
Expand All @@ -130,7 +173,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
assert(lb.getType() == ub.getType());
assert(lb.getType() == step.getType());

llvm::SmallVector<Value> newArgs;
SmallVector<Value> newArgs;

// Populate inits for new `scf.for`, skip induction var.
newArgs.reserve(loop.getInits().size());
Expand Down Expand Up @@ -164,6 +207,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
newArgs.emplace_back(newBodyArgs[i]);
}
}
if (argReorder) {
// Reorder arguments following the 'after' argument order from the original
// 'while' loop.
SmallVector<Value> args;
for (unsigned order : *argReorder)
args.push_back(newArgs[order]);
newArgs = args;
}

rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
newArgs);
Expand Down Expand Up @@ -205,6 +256,14 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
newArgs.clear();
llvm::append_range(newArgs, newLoop.getResults());
newArgs.insert(newArgs.begin() + argNumber, res);
if (argReorder) {
// Reorder arguments following the 'after' argument order from the original
// 'while' loop.
SmallVector<Value> results;
for (unsigned order : *argReorder)
results.push_back(newArgs[order]);
newArgs = results;
}
rewriter.replaceOp(loop, newArgs);
return newLoop;
}
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Dialect/SCF/uplift-while.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,33 @@ func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : i64
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : i64
// CHECK: return %[[R7]] : i64

// -----

// A case where all 'before' arguments are forwarded but reordered.
// The 'before' region takes (i32, index, f32) while `scf.condition` forwards
// the same values as (index, i32, f32), so the 'after' region arguments and
// the loop results use a different order than the 'before' arguments.
func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2.0 : f32
// Inits bind in 'before' order: %arg4 (i32), %arg3 (index), %arg5 (f32).
%0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (index, i32, f32) {
// The loop continues while the index value is below %arg1.
%1 = arith.cmpi slt, %arg3, %arg1 : index
// Every 'before' argument is forwarded exactly once, in a permuted order.
scf.condition(%1) %arg3, %arg4, %arg5 : index, i32, f32
} do {
^bb0(%arg3: index, %arg4: i32, %arg5: f32):
%1 = "test.test1"(%arg4) : (i32) -> i32
// Advance the index value by %arg2.
%added = arith.addi %arg3, %arg2 : index
%2 = "test.test2"(%arg5) : (f32) -> f32
// Yielded values follow the 'before' argument order (i32, index, f32).
scf.yield %1, %added, %2 : i32, index, f32
}
// Only the i32 and f32 results are returned; the index result is unused.
return %0#1, %0#2 : i32, f32
}

// CHECK-LABEL: func @uplift_while
// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
// CHECK-SAME: iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
// CHECK: %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32