Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 11 additions & 18 deletions mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,

using Pred = arith::CmpIPredicate;
Pred predicate = cmp.getPredicate();
if (predicate != Pred::slt && predicate != Pred::sgt)
if (predicate != Pred::slt && predicate != Pred::sgt &&
predicate != Pred::ult && predicate != Pred::ugt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is ult/ugt correct? scf::for will do a signed comparison so the new IR is not equivalent.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, Maybe better approach will be to try to convert unsigned comparisons to signed first (e.g. using int range analysis).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is ult/ugt correct? scf::for will do a signed comparison so the new IR is not equivalent.

Sorry, I may misread that. But, for scf.for, it reads that

 151     The `scf.for` operation represents a loop taking 3 SSA value as operands
 152     that represent the lower bound, upper bound and step respectively. The
 153     operation defines an SSA value for its induction variable. It has one
 154     region capturing the loop body. The induction variable is represented as an
 155     argument of this region. This SSA value is a signless integer or index.

Is that indvar signless integer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

signless means the sign information is held by the operation. The semantic of the op is clarified in the doc:

The lower and upper bounds specify a half-open range: the iteration is executed iff the signed comparison of induction variable value is less than the upper bound and bigger or equal to the lower bound.

return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
diag << "Expected 'slt'/'ult' or 'sgt'/'ugt' predicate: " << *cmp;
});

BlockArgument inductionVar;
Expand All @@ -103,24 +104,16 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
// Check if cmp has a suitable form. One of the arguments must be a `before`
// block arg, other must be defined outside `scf.while` and will be treated
// as upper bound.
for (bool reverse : {false, true}) {
auto expectedPred = reverse ? Pred::sgt : Pred::slt;
if (cmp.getPredicate() != expectedPred)
continue;

auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();

auto blockArg = dyn_cast<BlockArgument>(arg1);
if (!blockArg || blockArg.getOwner() != beforeBody)
continue;

if (!dom.properlyDominates(arg2, loop))
continue;

auto arg1 = cmp.getLhs();
auto arg2 = cmp.getRhs();
if (predicate == Pred::sgt || predicate == Pred::ugt)
std::swap(arg1, arg2);

auto blockArg = dyn_cast<BlockArgument>(arg1);
if (blockArg && blockArg.getOwner() == beforeBody &&
dom.properlyDominates(arg2, loop)) {
inductionVar = blockArg;
ub = arg2;
break;
}

if (!inductionVar)
Expand Down
64 changes: 64 additions & 0 deletions mlir/test/Dialect/SCF/uplift-while.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,67 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32

// -----

func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
%0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
%1 = arith.cmpi ult, %arg3, %arg1 : index
scf.condition(%1) %arg3 : index
} do {
^bb0(%arg3: index):
"test.test1"(%arg3) : (index) -> ()
%added = arith.addi %arg3, %arg2 : index
"test.test2"(%added) : (index) -> ()
scf.yield %added : index
}
return %0 : index
}

// CHECK-LABEL: func @uplift_while
// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
// CHECK: "test.test1"(%[[I]]) : (index) -> ()
// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
// CHECK: return %[[R7]] : index

// -----

func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
%0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
%1 = arith.cmpi ugt, %arg1, %arg3 : index
scf.condition(%1) %arg3 : index
} do {
^bb0(%arg3: index):
"test.test1"(%arg3) : (index) -> ()
%added = arith.addi %arg3, %arg2 : index
"test.test2"(%added) : (index) -> ()
scf.yield %added : index
}
return %0 : index
}

// CHECK-LABEL: func @uplift_while
// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
// CHECK: "test.test1"(%[[I]]) : (index) -> ()
// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
// CHECK: return %[[R7]] : index