Skip to content

Commit 1572e34

Browse files
[PIPELINE] Implementing expander option to leave stage predicates as an unresolved op (#6836)
We would like to implement epilogue peeling by custom amount of iterations to help with cases where last loop iteration is almost entirely predicated out except for final mmav5 wait. To avoid adding hard to debug complexity to the expander, we will peel the epilogue after the expansion and resolve the stage predicates manually (masking out the instructions from non-last iterations after the loop, and removing the mask from one-before-last iteration in the loop). Adding an option to delay resolving the predicate to after the expansion makes such transformation much easier, as we won't need to analyze the arithmetic ops used to build the logical predicates for the ops.
1 parent 4595f3a commit 1572e34

File tree

5 files changed

+92
-12
lines changed

5 files changed

+92
-12
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,18 @@ def TTG_LocalStoreOp : TTG_Op<"local_store"> {
313313
}];
314314
}
315315

316+
def TTG_PredicateStageOp: TTG_Op<"predicate_stage",
317+
[Pure, AllTypesMatch<["iv", "ub", "step"]>]> {
318+
let summary = "pipeliner stage predicate";
319+
let arguments = (ins AnySignlessIntegerOrIndex:$iv,
320+
AnySignlessIntegerOrIndex:$ub,
321+
AnySignlessIntegerOrIndex:$step,
322+
I32Attr:$maxStage,
323+
I32Attr:$stage);
324+
let results = (outs I1:$result);
325+
let assemblyFormat = "$iv `,` $ub `,` $step `maxStage` $maxStage `stage` $stage attr-dict `:` type($iv) `->` type($result)";
326+
}
327+
316328
def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
317329
let summary = "Upcast fp4 (e2m1) to fp";
318330

include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ struct PipeliningOption {
5757
/// pipeliner will have to predicate operations in the prologue/epilogue.
5858
bool supportDynamicLoops = false;
5959

60+
/// If set, use this function to emit the predicate stage ops instead of the
61+
/// default one.
62+
using EmitPredicateStageFnType = std::function<Value(
63+
RewriterBase &, Value, Value, Value, uint64_t, uint64_t)>;
64+
EmitPredicateStageFnType emitPredicateStageFn = nullptr;
65+
6066
// Callback to predicate operations when the prologue or epilogue are not
6167
// peeled. This takes the original operation, an i1 predicate value and the
6268
// pattern rewriter. It is expected to replace the given operation with
@@ -95,6 +101,10 @@ FailureOr<scf::ForOp> pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp,
95101
const PipeliningOption &options,
96102
bool *modifiedIR = nullptr);
97103

104+
Value emitPredicateForStage(RewriterBase &rewriter, Value inductionVar,
105+
Value upperBound, Value step, uint64_t maxStage,
106+
uint64_t stage);
107+
98108
} // namespace triton
99109
} // namespace mlir
100110

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ struct LoopPipelinerInternal {
6767
triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
6868
bool peelEpilogue;
6969
triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
70+
triton::PipeliningOption::EmitPredicateStageFnType emitPredicateStageFn =
71+
nullptr;
7072

7173
// When peeling the kernel we generate several version of each value for
7274
// different stage of the prologue. This map tracks the mapping between
@@ -160,6 +162,10 @@ bool LoopPipelinerInternal::initializeLoopInfo(
160162
LDBG("--no epilogue or predicate set -> BAIL");
161163
return false;
162164
}
165+
emitPredicateStageFn = options.emitPredicateStageFn;
166+
if (emitPredicateStageFn == nullptr) {
167+
emitPredicateStageFn = mlir::triton::emitPredicateForStage;
168+
}
163169
std::vector<std::pair<Operation *, unsigned>> schedule;
164170
options.getScheduleFn(forOp, schedule);
165171
if (schedule.empty()) {
@@ -490,20 +496,10 @@ LogicalResult LoopPipelinerInternal::createKernel(
490496
if (!peelEpilogue) {
491497
// Create a predicate for each stage except the last stage.
492498
Location loc = newForOp.getLoc();
493-
Type t = ub.getType();
494499
for (unsigned i = 0; i < maxStage; i++) {
495500
// c = ub - (maxStage - i) * step
496-
Value c = rewriter.create<arith::SubIOp>(
497-
loc, ub,
498-
rewriter.create<arith::MulIOp>(
499-
loc, step,
500-
rewriter.create<arith::ConstantOp>(
501-
loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
502-
503-
Value pred = rewriter.create<arith::CmpIOp>(
504-
newForOp.getLoc(), arith::CmpIPredicate::slt,
505-
newForOp.getInductionVar(), c);
506-
predicates[i] = pred;
501+
predicates[i] = emitPredicateStageFn(rewriter, newForOp.getInductionVar(),
502+
ub, step, maxStage, i);
507503
}
508504
}
509505
for (Operation *op : opOrder) {
@@ -852,3 +848,19 @@ mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
852848

853849
return newForOp;
854850
}
851+
852+
Value mlir::triton::emitPredicateForStage(RewriterBase &rewriter,
853+
Value inductionVar, Value upperBound,
854+
Value step, uint64_t maxStage,
855+
uint64_t stage) {
856+
auto loc = inductionVar.getLoc();
857+
auto type = inductionVar.getType();
858+
Value c = rewriter.create<arith::SubIOp>(
859+
loc, upperBound,
860+
rewriter.create<arith::MulIOp>(
861+
loc, step,
862+
rewriter.create<arith::ConstantOp>(
863+
loc, rewriter.getIntegerAttr(type, maxStage - stage))));
864+
return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
865+
inductionVar, c);
866+
}

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,17 @@ static void expandLoops(ModuleOp moduleOp) {
6262
std::vector<std::pair<Operation *, unsigned>> &schedule) {
6363
schedule = finalSchedule;
6464
};
65+
// Testing feature: allow for unresolved predicate stage ops
66+
// in the loop body.
67+
if (forOp->hasAttr("__test_keep_predicate_stage")) {
68+
options.emitPredicateStageFn =
69+
[](RewriterBase &rewriter, Value inductionVar, Value upperBound,
70+
Value step, uint64_t maxStage, uint64_t stage) {
71+
return rewriter.create<triton::gpu::PredicateStageOp>(
72+
inductionVar.getLoc(), inductionVar, upperBound, step, maxStage,
73+
stage);
74+
};
75+
}
6576
IRRewriter rewriter(forOp);
6677
FailureOr<scf::ForOp> newForOp =
6778
triton::pipelineForLoop(rewriter, forOp, options);

test/TritonGPU/loop-pipeline.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,3 +1696,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
16961696
tt.return
16971697
}
16981698
}
1699+
1700+
// -----
1701+
1702+
#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
1703+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1704+
// CHECK-LABEL: @predicate_stage1
1705+
// CHECK: scf.for %[[IV:.*]] = %[[LB:.*]] to %[[UB:.*]] step %[[STEP:.*]] iter_args
1706+
// CHECK: ttg.predicate_stage %[[IV]], %[[UB]], %[[STEP]] maxStage 2 stage 0 : i32 -> i1
1707+
tt.func public @predicate_stage1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} {
1708+
%c1024_i32 = arith.constant 1024 : i32
1709+
%c0_i32 = arith.constant 0 : i32
1710+
%c1016800_i32 = arith.constant 1016800 : i32
1711+
%0 = tt.get_program_id x : i32
1712+
%1 = arith.muli %0, %c1016800_i32 : i32
1713+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
1714+
%3 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked>
1715+
%4 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
1716+
%5 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
1717+
%6 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
1718+
scf.for %arg4 = %c0_i32 to %c1016800_i32 step %c1024_i32 : i32 {
1719+
%7 = arith.addi %1, %arg4 : i32
1720+
%8 = tt.splat %7 : i32 -> tensor<1024xi32, #blocked>
1721+
%9 = arith.addi %8, %2 : tensor<1024xi32, #blocked>
1722+
%10 = arith.cmpi slt, %9, %3 : tensor<1024xi32, #blocked>
1723+
%11 = tt.addptr %4, %9 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
1724+
%12 = tt.load %11, %10 : tensor<1024x!tt.ptr<f32>, #blocked>
1725+
%13 = tt.addptr %5, %9 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
1726+
%14 = tt.load %13, %10 : tensor<1024x!tt.ptr<f32>, #blocked>
1727+
%15 = arith.addf %12, %14 : tensor<1024xf32, #blocked>
1728+
%16 = tt.addptr %6, %9 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
1729+
tt.store %16, %15, %10 : tensor<1024x!tt.ptr<f32>, #blocked>
1730+
} {tt.num_stages = 3 : i32, __test_keep_predicate_stage}
1731+
tt.return
1732+
}
1733+
}

0 commit comments

Comments
 (0)