[mlir][SCF] Fix dynamic loop pipeline peeling for num_stages > total_iters #112418
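@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf
Author: SJW (sjw36)
Changes
When pipelining an `scf.for` with dynamic loop bounds, the epilogue ramp-down must align with the prologue when num_stages > total_iterations.
For example, take the loop `scf.for (0..ub) { load(i) add(i) store(i) }`. When num_stages=3 the pipeline follows `load(0) - add(0) - scf.for (0..ub-2) - store(ub-2)` on one row and `load(1) - - add(ub-1) - store(ub-1)` on the next.
The trailing `store(ub-2)`, `i=ub-2`, must align with the ramp-up for `i=0` when `ub < num_stages-1`, so the index `i` should be `max(0, ub-2)` and each subsequent index is an increment. The predicate must also handle this scenario, so it becomes `predicate[0] = total_iterations > epilogue_stage`.
Full diff: https://github.com/llvm/llvm-project/pull/112418.diff
2 Files Affected: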
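The epilogue index and predicate arithmetic this change introduces (`start_iter = max(0, total_iters - max_stage)` and `pred = total_iters >= i`) can be sketched in plain C++. This is a minimal model, not the pipeliner's code: `epilogueIter` and `epiloguePred` are hypothetical helper names, `version` stands for the loop counter `i` passed to `setValueMapping` in the patch, and the loop is assumed to be normalized to `lb = 0`, `step = 1`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Iteration index used by epilogue version `version` (1..maxStage).
// Models `start_iter = max(0, total_iters - max_stage)` followed by one
// increment per version. Assumes lb = 0 and step = 1.
int64_t epilogueIter(int64_t totalIters, int64_t maxStage, int64_t version) {
  assert(version >= 1 && version <= maxStage);
  int64_t start = std::max<int64_t>(0, totalIters - maxStage);
  return start + (version - 1);
}

// Predicate guarding epilogue version `version`: `total_iters >= i`.
bool epiloguePred(int64_t totalIters, int64_t version) {
  return totalIters >= version;
}
```

With num_stages=3 (`maxStage = 2`) and a single iteration, version 1 maps to iteration 0 and stays enabled, while version 2 maps to iteration 1 and is masked. With five iterations the two versions map to iterations 3 and 4 and both are enabled.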
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 83c9cf69ba0364..be75640b44bd9a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -642,22 +642,25 @@ LogicalResult
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
llvm::SmallVector<Value> &returnValues) {
Location loc = forOp.getLoc();
+ Type t = lb.getType();
+
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
- // bounds_range = ub - lb
- // total_iterations = (bounds_range + step - 1) / step
- Type t = lb.getType();
- Value zero =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
- Value one =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
- Value minusOne =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+ auto getConst = [&](int v) {
+ return rewriter.create<arith::ConstantOp>(loc,
+ rewriter.getIntegerAttr(t, v));
+ };
+
+ // total_iterations = cdiv(range_diff, step);
+ // - range_diff = ub - lb
+ // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
+ Value zero = getConst(0);
+ Value one = getConst(1);
Value stepLessZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, step, zero);
Value stepDecr =
- rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
+ rewriter.create<arith::SelectOp>(loc, stepLessZero, one, getConst(-1));
Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
@@ -665,25 +668,31 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
+ // If total_iters < max_stage, start the epilogue at zero to match the
+ // ramp-up in the prologue.
+ // start_iter = max(0, total_iters - max_stage)
+ Value iterI =
+ rewriter.create<arith::SubIOp>(loc, totalIterations, getConst(maxStage));
+ iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
+
+ // Capture predicates for dynamic loops.
SmallVector<Value> predicates(maxStage + 1);
- for (int64_t i = 0; i < maxStage; i++) {
- // iterI = total_iters - 1 - i
- // May go negative...
- Value minusI =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
- Value iterI = rewriter.create<arith::AddIOp>(
- loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
- minusI);
+
+ for (int64_t i = 1; i <= maxStage; i++) {
// newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
- setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+ setValueMapping(forOp.getInductionVar(), newlastIter, i);
+
+ // increment to next iterI
+ iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
if (dynamicLoop) {
- // pred = iterI >= 0
- predicates[i + 1] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, iterI, zero);
+ // Disable stages when `i` is greater than total_iters.
+ // pred = total_iters >= i
+ predicates[i] = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, totalIterations, getConst(i));
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index af49d2afc049ba..c879c83275bf86 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -767,6 +767,7 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// Check for predicated epilogue for dynamic loop.
// CHECK-LABEL: dynamic_loop(
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
@@ -779,32 +780,32 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]]
// CHECK: }
// CHECK: %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
-// CHECK: %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
-// CHECK: %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]]
-// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]]
-// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]]
-// CHECK: %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]]
-// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]]
-// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
-// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
-// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
-// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
-// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
-// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
-// CHECK: scf.if %[[CMPI_17]] {
-// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
+// CHECK: %[[SELECT_11:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
+// CHECK: %[[SUBI_12:.*]] = arith.subi %[[UB]], %[[LB]]
+// CHECK: %[[ADDI_13:.*]] = arith.addi %[[SUBI_12]], %[[STEP]]
+// CHECK: %[[ADDI_14:.*]] = arith.addi %[[ADDI_13]], %[[SELECT_11]]
+// CHECK: %[[DIVSI_15:.*]] = arith.divsi %[[ADDI_14]], %[[STEP]]
+// CHECK: %[[SUBI_17:.*]] = arith.subi %[[DIVSI_15]], %[[C2]]
+// CHECK: %[[MAXSI_18:.*]] = arith.maxsi %[[SUBI_17]], %[[C0]]
+// CHECK: %[[MULI_19:.*]] = arith.muli %[[STEP]], %[[MAXSI_18]]
+// CHECK: %[[ADDI_20:.*]] = arith.addi %[[LB]], %[[MULI_19]]
+// CHECK: %[[ADDI_21:.*]] = arith.addi %[[MAXSI_18]], %[[C1]]
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C1]]
+// CHECK: %[[MULI_23:.*]] = arith.muli %[[STEP]], %[[ADDI_21]]
+// CHECK: %[[ADDI_24:.*]] = arith.addi %[[LB]], %[[MULI_23]]
+// CHECK: %[[CMPI_25:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C2]]
+// CHECK: scf.if %[[CMPI_22]] {
+// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_20]]]
// CHECK: } else {
// CHECK: }
-// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
-// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
-// CHECK: scf.yield %[[ADDF_24]]
+// CHECK: %[[IF_26:.*]] = scf.if %[[CMPI_25]]
+// CHECK: %[[ADDF_27:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+// CHECK: scf.yield %[[ADDF_27]]
// CHECK: } else {
// CHECK: scf.yield %{{.*}}
// CHECK: }
-// CHECK: scf.if %[[CMPI_22]] {
-// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
+// CHECK: scf.if %[[CMPI_25]] {
+// CHECK: memref.store %[[IF_26]], %{{.*}}[%[[ADDI_24]]]
// CHECK: } else {
// CHECK: }
// CHECK: return
@@ -842,6 +843,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
+// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
// CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
@@ -856,22 +858,21 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK: %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]]
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]]
// CHECK: %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]]
-// CHECK: %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]]
-// CHECK: %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]]
-// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_11]]
-// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
-// CHECK: scf.yield %[[ADDF_13]]
+// CHECK: %[[CMPI_10:.*]] = arith.cmpi sge, %[[DIVSI_9]], %[[C1]]
+// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_10]]
+// CHECK: %[[ADDF_14:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
+// CHECK: scf.yield %[[ADDF_14]]
// CHECK: } else {
-// CHECK: scf.yield %{{.*}}
+// CHECK: scf.yield %[[CF0]]
// CHECK: }
-// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_11]]
-// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
-// CHECK: scf.yield %[[MULF_13]]
+// CHECK: %[[IF_12:.*]] = scf.if %[[CMPI_10]]
+// CHECK: %[[MULF_14:.*]] = arith.mulf %[[IF_11]], %{{.*}}
+// CHECK: scf.yield %[[MULF_14]]
// CHECK: } else {
-// CHECK: scf.yield %{{.*}}
+// CHECK: scf.yield %[[CF0]]
// CHECK: }
-// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0
-// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
+// CHECK: %[[SELECT_13:.*]] = arith.select %[[CMPI_10]], %[[IF_12]], %{{.*}}#0
+// CHECK: memref.store %[[SELECT_13]], %{{.*}}[%[[C0]]]
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf0 = arith.constant 1.0 : f32
%cf1 = arith.constant 33.0 : f32
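The `total_iterations` formula from the `emitEpilogue` comment, `(range_diff + step + (step < 0 ? 1 : -1)) / step`, can be checked with a standalone sketch. `totalIterations` is a hypothetical helper written for illustration; it relies on C++ signed division truncating toward zero, the same rounding `arith.divsi` uses.

```cpp
#include <cassert>
#include <cstdint>

// total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
// Signed `/` in C++ truncates toward zero, like `arith.divsi`.
int64_t totalIterations(int64_t lb, int64_t ub, int64_t step) {
  assert(step != 0 && "step must be non-zero");
  int64_t rangeDiff = ub - lb;
  int64_t stepDecr = step < 0 ? 1 : -1;
  return (rangeDiff + step + stepDecr) / step;
}
```

For `lb=0, ub=10, step=3` this gives 4 (iterations 0, 3, 6, 9), and the mirrored descending loop `lb=10, ub=0, step=-3` also gives 4. For an empty range such as `lb=5, ub=3, step=1` the formula yields -2; in the patch such values are clamped by the `max` with zero and masked by the `total_iters >= i` predicates.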
ThomasRaoux
left a comment
LGTM, thanks
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
Value minusOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
auto getConst = [&](int v) {
nit: maybe rename to createConst?
- The epilogue ramp-down indexing must start at max(0, total_iterations - max_stage) to ensure alignment with the prologue ramp-up stages.
- If total_iterations < max_stage, the trailing stages will be masked.
This commit mirrors upstream llvm/llvm-project#112418 and adds a functional test for correctness with num_stages=1,2,3,4.
…iters (llvm#112418)
When pipelining an `scf.for` with dynamic loop bounds, the epilogue ramp-down must align with the prologue when num_stages > total_iterations.
For example:
```
scf.for (0..ub) {
  load(i)
  add(i)
  store(i)
}
```
When num_stages=3 the pipeline follows:
```
load(0) - add(0) - scf.for (0..ub-2) - store(ub-2)
load(1) - - add(ub-1) - store(ub-1)
```
The trailing `store(ub-2)`, `i=ub-2`, must align with the ramp-up for `i=0` when `ub < num_stages-1`, so the index `i` should be `max(0, ub-2)` and each subsequent index is an increment. The predicate must also handle this scenario, so it becomes `predicate[0] = total_iterations > epilogue_stage`.
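To see why the clamped start index matters, the `load`/`add`/`store` example can be run through a small simulation. This is a sketch under stated assumptions, not the MLIR implementation: `StagedOp` and `simulatePipeline` are hypothetical names, the loop is normalized to `lb = 0`, `step = 1`, and the prologue is assumed to predicate each op on its iteration being below `total_iters`. The epilogue part follows the patch: it starts at `max(0, total_iters - max_stage)` and enables version `v` only when `total_iters >= v`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// One op of the loop body together with its pipeline stage.
struct StagedOp {
  std::string name;
  int64_t stage;
};

// Models the pipelined loop (lb = 0, step = 1) and returns, for each op,
// the iteration indices it executed on, in execution order.
std::map<std::string, std::vector<int64_t>>
simulatePipeline(const std::vector<StagedOp> &ops, int64_t maxStage,
                 int64_t totalIters) {
  std::map<std::string, std::vector<int64_t>> trace;
  for (const StagedOp &op : ops) {
    assert(op.stage >= 0 && op.stage <= maxStage);
    trace[op.name];  // make sure every op has an entry, even if empty
  }

  // Prologue: part i runs ops with stage <= i on iteration i - stage,
  // predicated on that iteration being in range (model assumption).
  for (int64_t i = 0; i < maxStage; ++i)
    for (const StagedOp &op : ops)
      if (op.stage <= i && i - op.stage < totalIters)
        trace[op.name].push_back(i - op.stage);

  // Kernel: the loop shortened by maxStage iterations.
  for (int64_t i = 0; i < totalIters - maxStage; ++i)
    for (const StagedOp &op : ops)
      trace[op.name].push_back(i + maxStage - op.stage);

  // Epilogue: clamped start index plus per-version predicate.
  int64_t start = std::max<int64_t>(0, totalIters - maxStage);
  for (int64_t i = 1; i <= maxStage; ++i)
    for (const StagedOp &op : ops) {
      if (op.stage < i)
        continue;
      int64_t version = maxStage - op.stage + i;
      if (totalIters >= version)
        trace[op.name].push_back(start + version - 1);
    }
  return trace;
}
```

In this model, with `load`, `add`, and `store` at stages 0, 1, and 2 and `maxStage = 2`, every op runs exactly once per iteration for any `totalIters >= 0`, including the cases `totalIters = 1` and `totalIters = 2` where the kernel loop body never executes.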