-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][scf] Allow 'ult'/'ugt' in uplift #139911
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf
Author: None (darkbuck)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/139911.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ebe718ae4fb61..0fabaf6e63ee4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -91,9 +91,10 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
using Pred = arith::CmpIPredicate;
Pred predicate = cmp.getPredicate();
- if (predicate != Pred::slt && predicate != Pred::sgt)
+ if (predicate != Pred::slt && predicate != Pred::sgt &&
+ predicate != Pred::ult && predicate != Pred::ugt)
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
- diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+ diag << "Expected 'slt'/'ult' or 'sgt'/'ugt' predicate: " << *cmp;
});
BlockArgument inductionVar;
@@ -103,24 +104,16 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
// Check if cmp has a suitable form. One of the arguments must be a `before`
// block arg, other must be defined outside `scf.while` and will be treated
// as upper bound.
- for (bool reverse : {false, true}) {
- auto expectedPred = reverse ? Pred::sgt : Pred::slt;
- if (cmp.getPredicate() != expectedPred)
- continue;
-
- auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
- auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
-
- auto blockArg = dyn_cast<BlockArgument>(arg1);
- if (!blockArg || blockArg.getOwner() != beforeBody)
- continue;
-
- if (!dom.properlyDominates(arg2, loop))
- continue;
-
+ auto arg1 = cmp.getLhs();
+ auto arg2 = cmp.getRhs();
+ if (predicate == Pred::sgt || predicate == Pred::ugt)
+ std::swap(arg1, arg2);
+
+ auto blockArg = dyn_cast<BlockArgument>(arg1);
+ if (blockArg && blockArg.getOwner() == beforeBody &&
+ dom.properlyDominates(arg2, loop)) {
inductionVar = blockArg;
ub = arg2;
- break;
}
if (!inductionVar)
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
index cbe2ce5076ad2..f11f5ab28d707 100644
--- a/mlir/test/Dialect/SCF/uplift-while.mlir
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -185,3 +185,67 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi ult, %arg3, %arg1 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg3, %arg2 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi ugt, %arg1, %arg3 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg3, %arg2 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
|
| Pred predicate = cmp.getPredicate(); | ||
| if (predicate != Pred::slt && predicate != Pred::sgt) | ||
| if (predicate != Pred::slt && predicate != Pred::sgt && | ||
| predicate != Pred::ult && predicate != Pred::ugt) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is ult/ugt correct? scf::for will do a signed comparison so the new IR is not equivalent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1. Maybe a better approach would be to first try to convert unsigned comparisons to signed ones (e.g. using integer range analysis).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is ult/ugt correct? scf::for will do a signed comparison so the new IR is not equivalent.
Sorry, I may have misread that. But the documentation for scf.for reads:
151 The `scf.for` operation represents a loop taking 3 SSA value as operands
152 that represent the lower bound, upper bound and step respectively. The
153 operation defines an SSA value for its induction variable. It has one
154 region capturing the loop body. The induction variable is represented as an
155 argument of this region. This SSA value is a signless integer or index.
Doesn't that mean the induction variable is a signless integer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Signless means the sign information is held by the operation, not by the type. The semantics of the op are clarified in the documentation:
The lower and upper bounds specify a half-open range: the iteration is executed iff the signed comparison of induction variable value is less than the upper bound and bigger or equal to the lower bound.
No description provided.