Skip to content

Commit 2013ca3

Browse files
committed
[ mlir][scf] Allow 'ult'/'ugt' in uplift
1 parent 7e690db commit 2013ca3

File tree

2 files changed

+75
-18
lines changed

2 files changed

+75
-18
lines changed

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,10 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
9191

9292
using Pred = arith::CmpIPredicate;
9393
Pred predicate = cmp.getPredicate();
94-
if (predicate != Pred::slt && predicate != Pred::sgt)
94+
if (predicate != Pred::slt && predicate != Pred::sgt &&
95+
predicate != Pred::ult && predicate != Pred::ugt)
9596
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
96-
diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
97+
diag << "Expected 'slt'/'ult' or 'sgt'/'ugt' predicate: " << *cmp;
9798
});
9899

99100
BlockArgument inductionVar;
@@ -103,24 +104,16 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
103104
// Check if cmp has a suitable form. One of the arguments must be a `before`
104105
// block arg, other must be defined outside `scf.while` and will be treated
105106
// as upper bound.
106-
for (bool reverse : {false, true}) {
107-
auto expectedPred = reverse ? Pred::sgt : Pred::slt;
108-
if (cmp.getPredicate() != expectedPred)
109-
continue;
110-
111-
auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
112-
auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
113-
114-
auto blockArg = dyn_cast<BlockArgument>(arg1);
115-
if (!blockArg || blockArg.getOwner() != beforeBody)
116-
continue;
117-
118-
if (!dom.properlyDominates(arg2, loop))
119-
continue;
120-
107+
auto arg1 = cmp.getLhs();
108+
auto arg2 = cmp.getRhs();
109+
if (predicate == Pred::sgt || predicate == Pred::ugt)
110+
std::swap(arg1, arg2);
111+
112+
auto blockArg = dyn_cast<BlockArgument>(arg1);
113+
if (blockArg && blockArg.getOwner() == beforeBody &&
114+
dom.properlyDominates(arg2, loop)) {
121115
inductionVar = blockArg;
122116
ub = arg2;
123-
break;
124117
}
125118

126119
if (!inductionVar)

mlir/test/Dialect/SCF/uplift-while.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,67 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
185185
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
186186
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
187187
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
188+
189+
// -----
190+
191+
func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
192+
%0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
193+
%1 = arith.cmpi ult, %arg3, %arg1 : index
194+
scf.condition(%1) %arg3 : index
195+
} do {
196+
^bb0(%arg3: index):
197+
"test.test1"(%arg3) : (index) -> ()
198+
%added = arith.addi %arg3, %arg2 : index
199+
"test.test2"(%added) : (index) -> ()
200+
scf.yield %added : index
201+
}
202+
return %0 : index
203+
}
204+
205+
// CHECK-LABEL: func @uplift_while
206+
// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
207+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
208+
// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
209+
// CHECK: "test.test1"(%[[I]]) : (index) -> ()
210+
// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
211+
// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
212+
// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
213+
// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
214+
// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
215+
// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
216+
// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
217+
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
218+
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
219+
// CHECK: return %[[R7]] : index
220+
221+
// -----
222+
223+
func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
224+
%0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
225+
%1 = arith.cmpi ugt, %arg1, %arg3 : index
226+
scf.condition(%1) %arg3 : index
227+
} do {
228+
^bb0(%arg3: index):
229+
"test.test1"(%arg3) : (index) -> ()
230+
%added = arith.addi %arg3, %arg2 : index
231+
"test.test2"(%added) : (index) -> ()
232+
scf.yield %added : index
233+
}
234+
return %0 : index
235+
}
236+
237+
// CHECK-LABEL: func @uplift_while
238+
// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
239+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
240+
// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
241+
// CHECK: "test.test1"(%[[I]]) : (index) -> ()
242+
// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
243+
// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
244+
// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
245+
// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
246+
// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
247+
// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
248+
// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
249+
// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
250+
// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
251+
// CHECK: return %[[R7]] : index

0 commit comments

Comments
 (0)