From 2013ca3703590323e573d001a7b0891cec552645 Mon Sep 17 00:00:00 2001 From: Michael Liao Date: Wed, 14 May 2025 09:02:34 -0400 Subject: [PATCH] [ mlir][scf] Allow 'ult'/'ugt' in uplift --- .../SCF/Transforms/UpliftWhileToFor.cpp | 29 ++++----- mlir/test/Dialect/SCF/uplift-while.mlir | 64 +++++++++++++++++++ 2 files changed, 75 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp index ebe718ae4fb61..0fabaf6e63ee4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp @@ -91,9 +91,10 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, using Pred = arith::CmpIPredicate; Pred predicate = cmp.getPredicate(); - if (predicate != Pred::slt && predicate != Pred::sgt) + if (predicate != Pred::slt && predicate != Pred::sgt && + predicate != Pred::ult && predicate != Pred::ugt) return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { - diag << "Expected 'slt' or 'sgt' predicate: " << *cmp; + diag << "Expected 'slt'/'ult' or 'sgt'/'ugt' predicate: " << *cmp; }); BlockArgument inductionVar; @@ -103,24 +104,16 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, // Check if cmp has a suitable form. One of the arguments must be a `before` // block arg, other must be defined outside `scf.while` and will be treated // as upper bound. - for (bool reverse : {false, true}) { - auto expectedPred = reverse ? Pred::sgt : Pred::slt; - if (cmp.getPredicate() != expectedPred) - continue; - - auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs(); - auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs(); - - auto blockArg = dyn_cast(arg1); - if (!blockArg || blockArg.getOwner() != beforeBody) - continue; - - if (!dom.properlyDominates(arg2, loop)) - continue; - + auto arg1 = cmp.getLhs(); + auto arg2 = cmp.getRhs(); + if (predicate == Pred::sgt || predicate == Pred::ugt) + std::swap(arg1, arg2); + + auto blockArg = dyn_cast(arg1); + if (blockArg && blockArg.getOwner() == beforeBody && + dom.properlyDominates(arg2, loop)) { inductionVar = blockArg; ub = arg2; - break; } if (!inductionVar) diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir index cbe2ce5076ad2..f11f5ab28d707 100644 --- a/mlir/test/Dialect/SCF/uplift-while.mlir +++ b/mlir/test/Dialect/SCF/uplift-while.mlir @@ -185,3 +185,67 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) // CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32 // CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32 // CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32 + +// ----- + +func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index { + %0 = scf.while (%arg3 = %arg0) : (index) -> (index) { + %1 = arith.cmpi ult, %arg3, %arg1 : index + scf.condition(%1) %arg3 : index + } do { + ^bb0(%arg3: index): + "test.test1"(%arg3) : (index) -> () + %added = arith.addi %arg3, %arg2 : index + "test.test2"(%added) : (index) -> () + scf.yield %added : index + } + return %0 : index +} + +// CHECK-LABEL: func @uplift_while +// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] { +// CHECK: "test.test1"(%[[I]]) : (index) -> () +// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index +// CHECK: "test.test2"(%[[INC]]) : (index) -> () +// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index +// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index +// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index +// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index +// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index +// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index +// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index +// CHECK: return %[[R7]] : index + +// ----- + +func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index { + %0 = scf.while (%arg3 = %arg0) : (index) -> (index) { + %1 = arith.cmpi ugt, %arg1, %arg3 : index + scf.condition(%1) %arg3 : index + } do { + ^bb0(%arg3: index): + "test.test1"(%arg3) : (index) -> () + %added = arith.addi %arg3, %arg2 : index + "test.test2"(%added) : (index) -> () + scf.yield %added : index + } + return %0 : index +} + +// CHECK-LABEL: func @uplift_while +// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] { +// CHECK: "test.test1"(%[[I]]) : (index) -> () +// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index +// CHECK: "test.test2"(%[[INC]]) : (index) -> () +// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index +// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index +// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index +// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index +// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index +// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index +// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index +// CHECK: return %[[R7]] : index