Skip to content

Commit 8da5aa1

Browse files
authored
[mlir][SCF] Fix dynamic loop pipeline peeling for num_stages > total_iters (#112418)
When pipelining an `scf.for` with dynamic loop bounds, the epilogue ramp-down must align with the prologue when `num_stages > total_iterations`.

For example:
```
scf.for (0..ub) {
  load(i)
  add(i)
  store(i)
}
```
When num_stages=3 the pipeline follows:
```
load(0)  -  add(0)   -  scf.for (0..ub-2)  -  store(ub-2)
            load(1)  -                     -  add(ub-1)    -  store(ub-1)
```
The trailing `store(ub-2)`, `i=ub-2`, must align with the ramp-up for `i=0` when `ub < num_stages-1`, so the index `i` should be `max(0, ub-2)` and each subsequent index is an increment. The predicate must also handle this scenario, so it becomes `predicate[0] = total_iterations > epilogue_stage`.
1 parent de7f7ea commit 8da5aa1

File tree

2 files changed

+65
-55
lines changed

2 files changed

+65
-55
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -642,48 +642,57 @@ LogicalResult
642642
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
643643
llvm::SmallVector<Value> &returnValues) {
644644
Location loc = forOp.getLoc();
645+
Type t = lb.getType();
646+
645647
// Emit different versions of the induction variable. They will be
646648
// removed by dead code if not used.
647649

648-
// bounds_range = ub - lb
649-
// total_iterations = (bounds_range + step - 1) / step
650-
Type t = lb.getType();
651-
Value zero =
652-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
653-
Value one =
654-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
655-
Value minusOne =
656-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
650+
auto createConst = [&](int v) {
651+
return rewriter.create<arith::ConstantOp>(loc,
652+
rewriter.getIntegerAttr(t, v));
653+
};
654+
655+
// total_iterations = cdiv(range_diff, step);
656+
// - range_diff = ub - lb
657+
// - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
658+
Value zero = createConst(0);
659+
Value one = createConst(1);
657660
Value stepLessZero = rewriter.create<arith::CmpIOp>(
658661
loc, arith::CmpIPredicate::slt, step, zero);
659662
Value stepDecr =
660-
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
663+
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
661664

662665
Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
663666
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
664667
Value rangeDecr =
665668
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
666669
Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
667670

671+
// If total_iters < max_stage, start the epilogue at zero to match the
672+
// ramp-up in the prologue.
673+
// start_iter = max(0, total_iters - max_stage)
674+
Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
675+
createConst(maxStage));
676+
iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
677+
678+
// Capture predicates for dynamic loops.
668679
SmallVector<Value> predicates(maxStage + 1);
669-
for (int64_t i = 0; i < maxStage; i++) {
670-
// iterI = total_iters - 1 - i
671-
// May go negative...
672-
Value minusI =
673-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
674-
Value iterI = rewriter.create<arith::AddIOp>(
675-
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
676-
minusI);
680+
681+
for (int64_t i = 1; i <= maxStage; i++) {
677682
// newLastIter = lb + step * iterI
678683
Value newlastIter = rewriter.create<arith::AddIOp>(
679684
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
680685

681-
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
686+
setValueMapping(forOp.getInductionVar(), newlastIter, i);
687+
688+
// increment to next iterI
689+
iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
682690

683691
if (dynamicLoop) {
684-
// pred = iterI >= 0
685-
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
686-
loc, arith::CmpIPredicate::sge, iterI, zero);
692+
// Disable stages when `i` is greater than total_iters.
693+
// pred = total_iters >= i
694+
predicates[i] = rewriter.create<arith::CmpIOp>(
695+
loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
687696
}
688697
}
689698

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,7 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
767767
// Check for predicated epilogue for dynamic loop.
768768
// CHECK-LABEL: dynamic_loop(
769769
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
770+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
770771
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
771772
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
772773
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
@@ -779,32 +780,32 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
779780
// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]]
780781
// CHECK: }
781782
// CHECK: %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
782-
// CHECK: %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
783-
// CHECK: %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]]
784-
// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]]
785-
// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]]
786-
// CHECK: %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]]
787-
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]]
788-
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
789-
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
790-
// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
791-
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
792-
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
793-
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
794-
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
795-
// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
796-
// CHECK: scf.if %[[CMPI_17]] {
797-
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
783+
// CHECK: %[[SELECT_11:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
784+
// CHECK: %[[SUBI_12:.*]] = arith.subi %[[UB]], %[[LB]]
785+
// CHECK: %[[ADDI_13:.*]] = arith.addi %[[SUBI_12]], %[[STEP]]
786+
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[ADDI_13]], %[[SELECT_11]]
787+
// CHECK: %[[DIVSI_15:.*]] = arith.divsi %[[ADDI_14]], %[[STEP]]
788+
// CHECK: %[[SUBI_17:.*]] = arith.subi %[[DIVSI_15]], %[[C2]]
789+
// CHECK: %[[MAXSI_18:.*]] = arith.maxsi %[[SUBI_17]], %[[C0]]
790+
// CHECK: %[[MULI_19:.*]] = arith.muli %[[STEP]], %[[MAXSI_18]]
791+
// CHECK: %[[ADDI_20:.*]] = arith.addi %[[LB]], %[[MULI_19]]
792+
// CHECK: %[[ADDI_21:.*]] = arith.addi %[[MAXSI_18]], %[[C1]]
793+
// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C1]]
794+
// CHECK: %[[MULI_23:.*]] = arith.muli %[[STEP]], %[[ADDI_21]]
795+
// CHECK: %[[ADDI_24:.*]] = arith.addi %[[LB]], %[[MULI_23]]
796+
// CHECK: %[[CMPI_25:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C2]]
797+
// CHECK: scf.if %[[CMPI_22]] {
798+
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_20]]]
798799
// CHECK: } else {
799800
// CHECK: }
800-
// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
801-
// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
802-
// CHECK: scf.yield %[[ADDF_24]]
801+
// CHECK: %[[IF_26:.*]] = scf.if %[[CMPI_25]]
802+
// CHECK: %[[ADDF_27:.*]] = arith.addf %{{.*}}#1, %{{.*}}
803+
// CHECK: scf.yield %[[ADDF_27]]
803804
// CHECK: } else {
804805
// CHECK: scf.yield %{{.*}}
805806
// CHECK: }
806-
// CHECK: scf.if %[[CMPI_22]] {
807-
// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
807+
// CHECK: scf.if %[[CMPI_25]] {
808+
// CHECK: memref.store %[[IF_26]], %{{.*}}[%[[ADDI_24]]]
808809
// CHECK: } else {
809810
// CHECK: }
810811
// CHECK: return
@@ -842,6 +843,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
842843
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
843844
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
844845
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
846+
// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00
845847
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
846848
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
847849
// CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
@@ -856,22 +858,21 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
856858
// CHECK: %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]]
857859
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]]
858860
// CHECK: %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]]
859-
// CHECK: %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]]
860-
// CHECK: %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]]
861-
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_11]]
862-
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
863-
// CHECK: scf.yield %[[ADDF_13]]
861+
// CHECK: %[[CMPI_10:.*]] = arith.cmpi sge, %[[DIVSI_9]], %[[C1]]
862+
// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_10]]
863+
// CHECK: %[[ADDF_14:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
864+
// CHECK: scf.yield %[[ADDF_14]]
864865
// CHECK: } else {
865-
// CHECK: scf.yield %{{.*}}
866+
// CHECK: scf.yield %[[CF0]]
866867
// CHECK: }
867-
// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_11]]
868-
// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
869-
// CHECK: scf.yield %[[MULF_13]]
868+
// CHECK: %[[IF_12:.*]] = scf.if %[[CMPI_10]]
869+
// CHECK: %[[MULF_14:.*]] = arith.mulf %[[IF_11]], %{{.*}}
870+
// CHECK: scf.yield %[[MULF_14]]
870871
// CHECK: } else {
871-
// CHECK: scf.yield %{{.*}}
872+
// CHECK: scf.yield %[[CF0]]
872873
// CHECK: }
873-
// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0
874-
// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
874+
// CHECK: %[[SELECT_13:.*]] = arith.select %[[CMPI_10]], %[[IF_12]], %{{.*}}#0
875+
// CHECK: memref.store %[[SELECT_13]], %{{.*}}[%[[C0]]]
875876
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
876877
%cf0 = arith.constant 1.0 : f32
877878
%cf1 = arith.constant 33.0 : f32

0 commit comments

Comments
 (0)