Skip to content

Commit 4d711bd

Browse files
authored
[Pipeliner] Fix loop iteration calculation for negative step (#4786)
This fixes the loop iteration count calculation when the step is negative: in that case the delta added to the range before the ceiling division must be `step+1` rather than `step-1`.
1 parent 6152840 commit 4d711bd

File tree

2 files changed

+35
-23
lines changed

2 files changed

+35
-23
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -657,18 +657,25 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
657657
// Emit different versions of the induction variable. They will be
658658
// removed by dead code if not used.
659659

660-
// bounds_range = ub - lb
661-
// total_iterations = (bounds_range + step - 1) / step
660+
// range_diff = ub - lb
661+
// total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
662662
Type t = lb.getType();
663-
Value minus1 =
664-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
665-
Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
666-
Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
667-
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
668-
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
669-
670663
Value zero =
671664
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
665+
Value one =
666+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
667+
Value minusOne =
668+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
669+
Value stepLessZero = rewriter.create<arith::CmpIOp>(
670+
loc, arith::CmpIPredicate::slt, step, zero);
671+
Value stepDecr =
672+
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
673+
674+
Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
675+
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
676+
Value rangeDecr =
677+
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
678+
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
672679

673680
// Capture predicates for dynamic loops.
674681
SmallVector<Value> predicates(maxStage + 1);
@@ -679,7 +686,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
679686
Value minusI =
680687
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
681688
Value iterI = rewriter.create<arith::AddIOp>(
682-
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
689+
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
683690
minusI);
684691
// newLastIter = lb + step * iterI
685692
Value newlastIter = rewriter.create<arith::AddIOp>(

test/TritonGPU/loop-pipeline.mlir

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,11 @@
5757
// CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]]
5858

5959
// AMD-LABEL: tt.func @matmul_loop
60+
// AMD-DAG: %[[CM1:.*]] = arith.constant -1 : index
61+
// AMD-DAG: %[[C1:.*]] = arith.constant 1 : index
6062
// AMD-DAG: %[[C0:.*]] = arith.constant 0 : index
61-
// AMD: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}})
63+
// AMD: %[[UB1:.*]] = arith.subi %[[UB:.*]], %arg2 : index
64+
// AMD: %[[FOR:.*]]:6 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UB1]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}})
6265
// AMD: %[[LOCAL_LOAD_32:.*]] = triton_gpu.local_load %[[ARG10]]
6366
// AMD: %[[LOCAL_LOAD_33:.*]] = triton_gpu.local_load %[[ARG11]]
6467
// AMD: %[[MULF_34:.*]] = arith.mulf %[[LOCAL_LOAD_33]], %{{.*}}
@@ -76,22 +79,24 @@
7679
// AMD: triton_gpu.local_store %[[LOAD_39]], %[[MEMDESC_SUBVIEW_44]]
7780
// AMD: scf.yield %[[ADDPTR_36]], %[[ADDPTR_37]], %[[DOT_35]], %[[SELECT_42]], %[[MEMDESC_SUBVIEW_43]], %[[MEMDESC_SUBVIEW_44]]
7881
// AMD: }
79-
// AMD: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}}
80-
// AMD: %[[ADDI_22:.*]] = arith.addi %[[SUBI_21]], %{{.*}}
81-
// AMD: %[[ADDI_23:.*]] = arith.addi %[[ADDI_22]], %{{.*}}-1
82-
// AMD: %[[DIVUI_24:.*]] = arith.divui %[[ADDI_23]], %{{.*}}
83-
// AMD: %[[ADDI_25:.*]] = arith.addi %[[DIVUI_24]], %{{.*}}-1
84-
// AMD: %[[CMPI_26:.*]] = arith.cmpi sge, %[[ADDI_25]], %[[C0]]
85-
// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %{{.*}}#4
86-
// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#5
82+
// AMD: %[[CMPI_21:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
83+
// AMD: %[[SELECT_22:.*]] = arith.select %[[CMPI_21]], %[[C1]], %[[CM1]]
84+
// AMD: %[[SUBI_23:.*]] = arith.subi %[[UB]], %[[LB]]
85+
// AMD: %[[ADDI_24:.*]] = arith.addi %[[SUBI_23]], %[[STEP]]
86+
// AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]]
87+
// AMD: %[[DIVUI_26:.*]] = arith.divui %[[ADDI_25]], %[[STEP]]
88+
// AMD: %[[ADDI_27:.*]] = arith.addi %[[DIVUI_26]], %[[CM1]]
89+
// AMD: %[[CMPI_28:.*]] = arith.cmpi sge, %[[ADDI_27]], %[[C0]]
90+
// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[FOR]]#4
91+
// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[FOR]]#5
8792
// AMD: %[[MULF_29:.*]] = arith.mulf %[[LOCAL_LOAD_28]], %{{.*}}
88-
// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]]
89-
// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %{{.*}}#2
93+
// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_28]]
94+
// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %[[FOR]]#2
9095
// AMD: scf.yield %[[DOT_32]]
9196
// AMD: } else {
92-
// AMD: scf.yield %{{.*}}#2
97+
// AMD: scf.yield %[[FOR]]#2
9398
// AMD: }
94-
// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#2
99+
// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_28]], %[[IF_30]], %[[FOR]]#2
95100
// AMD: triton_gpu.local_dealloc %{{.*}}
96101
// AMD: triton_gpu.local_dealloc %{{.*}}
97102

0 commit comments

Comments
 (0)