Skip to content

Commit fa089b0

Browse files
authored
[SCF] Fixed epilogue predicates in loop pipelining (#108964)
The computed loop iteration is zero based, so only check it is less than zero. This fixes the case when lower bound is not zero.
1 parent b30b9eb commit fa089b0

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
655655
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
656656
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
657657

658+
Value zero =
659+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
660+
658661
SmallVector<Value> predicates(maxStage + 1);
659662
for (int64_t i = 0; i < maxStage; i++) {
660663
// iterI = total_iters - 1 - i
@@ -671,9 +674,9 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
671674
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
672675

673676
if (dynamicLoop) {
674-
// pred = iterI >= lb
677+
// pred = iterI >= 0
675678
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
676-
loc, arith::CmpIPredicate::sge, iterI, lb);
679+
loc, arith::CmpIPredicate::sge, iterI, zero);
677680
}
678681
}
679682

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,7 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
766766

767767
// Check for predicated epilogue for dynamic loop.
768768
// CHECK-LABEL: dynamic_loop(
769+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
769770
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
770771
// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
771772
// CHECK: %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
@@ -781,12 +782,12 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
781782
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
782783
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
783784
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
784-
// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
785+
// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
785786
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
786787
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
787788
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
788789
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
789-
// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
790+
// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
790791
// CHECK: scf.if %[[CMPI_17]] {
791792
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
792793
// CHECK: } else {

0 commit comments

Comments
 (0)