Commit 44ecbec

[PIPELINE] Re-enabling mmav5 pipelining after fixing some performance regressions (#6761)
The previous attempt at reworking `ttng.wait_barrier` placement (triton-lang/triton#6613) and enabling mmav5 pipelining (triton-lang/triton#6732) was overly conservative about overlapping an mmav5 op with itself when the conditional `tmem_load` was placed in the same stage as the mma op. We can still overlap such cases; we just need to add the wait to the conditional block containing the `tmem_load`. This PR addresses that by adding a second "latency" kind to the mma op, `tt.self_latency`, which expresses to which stage the `wait_barrier` should be pushed. It also paves the way for letting the user independently control in which stage the users of the mma are placed (controlling the tmem buffer count) and how the mma overlaps with itself.
1 parent 8557148 commit 44ecbec
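With this change an MMA op can carry two independent latencies: the existing latency controls how many stages later the op's users (e.g. the conditional `tmem_load`) are scheduled, which drives the tmem buffer count, while the new self latency controls how many stages the op's own `wait_barrier` may be pushed back, i.e. how deeply the MMA overlaps with itself. A minimal sketch of reading the two values back during scheduling; the helper below is hypothetical, and the attribute spellings `tt.latency`/`tt.self_latency` follow the dialect definition and the commit message:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical helper: read an integer latency attribute, defaulting to 0
// when the op does not carry it.
static int getLatencyAttr(Operation *op, llvm::StringRef name) {
  if (auto attr = op->getAttrOfType<IntegerAttr>(name))
    return attr.getInt();
  return 0;
}

// Usage (mma is the MMAv5 op inside the loop being pipelined):
//   int latency     = getLatencyAttr(mma, "tt.latency");      // users -> later stage
//   int selfLatency = getLatencyAttr(mma, "tt.self_latency"); // wait_barrier -> later stage
```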

File tree

16 files changed: +657 −378 lines

include/triton/Dialect/Triton/IR/TritonDialect.td

Lines changed: 2 additions & 1 deletion

@@ -45,7 +45,8 @@ def Triton_Dialect : Dialect {
 
   let discardableAttrs = (ins
     "::mlir::IntegerAttr":$num_stages,
-    "::mlir::IntegerAttr":$latency
+    "::mlir::IntegerAttr":$latency,
+    "::mlir::IntegerAttr":$self_latency
   );
 
   let hasConstantMaterializer = 1;

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 31 additions & 11 deletions

@@ -20,11 +20,37 @@ namespace triton::nvidia_gpu {
 // Given an MMAv5 operation in a loop, determine if its accumulator can be
 // multibuffered.
 bool isAccMultibufferingPossible(MMAv5OpInterface mma, scf::ForOp forOp);
-// Only pipeline the loops where the MMA happens before the tmem_load, or is in
-// the same stage as the tmem_load. Lowering does not support the case where the
-// MMA is in a different stage as the tmem_load and happens after it.
-bool mmav5DominatesTmemLoads(
-    scf::ForOp forOp, function_ref<bool(MMAv5OpInterface)> isMmaPipelineable);
+
+// Returns true if the MMA operation requires acc multi-buffering when
+// pipelined.
+bool requiresAccMultiBuffering(MMAv5OpInterface mma, scf::ForOp forOp);
+
+// Returns true if there are loads from tmem after the MMA operation.
+bool hasLoadsAfterMMA(MMAv5OpInterface mma, scf::ForOp forOp);
+
+// Helper class to determine if the operands of an MMA operation are
+// pipelineable.
+class MMAv5PipelineableOperandsHelper {
+public:
+  MMAv5PipelineableOperandsHelper(
+      MMAv5OpInterface mmaOp, scf::ForOp forOp,
+      std::function<bool(Operation *)> isLoadToBePipelined)
+      : mmaOp(mmaOp), forOp(forOp), isLoadToBePipelined(isLoadToBePipelined) {
+    run();
+  }
+  bool isPipelineable = false;
+  // If true, the existing operand loads have all been found and their
+  // pipelineability has been determined.
+  bool isOperandsStateDetermined = false;
+  SmallVector<Operation *> unpipelineableOperandLoads;
+
+private:
+  MMAv5OpInterface mmaOp;
+  scf::ForOp forOp;
+  std::function<bool(Operation *)> isLoadToBePipelined;
+  bool comesFromLoadOrOutsideLoop(Value v, Operation *&foundLoad);
+  void run();
+};
 
 //===----------------------------------------------------------------------===//
 // MMA Pipeline Rewriters
@@ -35,12 +61,6 @@ bool mmav5DominatesTmemLoads(
 TMEMAllocOp createTMemAlloc(OpBuilder &builder, TMEMAllocOp oldTMemAllocOp,
                             bool multiBufferred, int numStages);
 
-// Return true if operands of the MMA operation are/are going to be pipelined
-// and multibuffered, enabling the MMA operation to be pipelined.
-bool mmaHasPipelineableOperands(
-    MMAv5OpInterface mma, scf::ForOp forOp,
-    std::function<bool(Operation *)> isLoadPipelineable);
-
 // Return true if the accumulator of an mma in subsequent iterations is either
 // independent from the previous iteration (overwritten) or completely reused,
 // without read-modify-write.
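The deleted `mmaHasPipelineableOperands` free function is superseded by this helper: callers construct it and inspect the result fields. A rough usage fragment, assuming `mma`, `forOp`, and an `opLatency` map are in scope (the lambda mirrors the one in AssignLatencies.cpp further down):

```cpp
// Illustrative fragment only, not part of the diff.
auto isLoadToBePipelined = [&](Operation *op) {
  return opLatency.count(op) && opLatency[op] > 0;
};
auto pipeHelper =
    ttng::MMAv5PipelineableOperandsHelper(mma, forOp, isLoadToBePipelined);
if (pipeHelper.isPipelineable) {
  // The MMA's operands are (or will be) pipelined and multibuffered, so the
  // MMA itself can be pipelined.
} else if (pipeHelper.isOperandsStateDetermined) {
  // The operand loads are known, but the ones listed in
  // pipeHelper.unpipelineableOperandLoads will not be pipelined.
}
```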

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 8 additions & 0 deletions

@@ -91,6 +91,11 @@ int getCopyVecBytes(RankedTensorType registerTy,
 // attribute.
 void serializeLatencies(ModuleOp module, DenseMap<Operation *, int> &opLatency);
 
+// Serialize the self latencies of the operations in the loops into the
+// self_latency attribute.
+void serializeSelfLatencies(ModuleOp module,
+                            DenseMap<Operation *, int> &opSelfLatency);
+
 // Deserialize the latencies of the operations in the loops from the attribute.
 DenseMap<Operation *, int> deserializeLatencies(Operation *op);
 
@@ -107,6 +112,9 @@ Value createAlloc(scf::ForOp forOp, RankedTensorType ty, Location loc,
 // Determine if the operation is a TMA load.
 bool isTMALoad(Operation *op);
 
+// Determine if the operation can be lowered to an async load.
+bool canBeAsyncLoad(Operation *op);
+
 // Look for consecutive wait ops and combine them into a single wait op.
 void combineRedundantWaitOps(
     llvm::SmallSetVector<gpu::AsyncWaitOp, 8> &waitOps);
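`serializeSelfLatencies` presumably mirrors the existing `serializeLatencies`, writing each map entry onto its op as an integer attribute. A hedged sketch of what such an implementation could look like (not the actual body from this commit; the `tt.self_latency` spelling comes from the commit message):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;

// Hedged sketch: attach each computed self latency to its op as the
// tt.self_latency integer attribute, mirroring serializeLatencies.
void serializeSelfLatencies(ModuleOp module,
                            DenseMap<Operation *, int> &opSelfLatency) {
  Builder b(module.getContext());
  for (auto &[op, selfLatency] : opSelfLatency)
    op->setAttr("tt.self_latency", b.getI32IntegerAttr(selfLatency));
}
```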

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 14 additions & 0 deletions

@@ -125,6 +125,20 @@ class CoarseSchedule {
 
   auto find(Operation *op) const { return opToStageAndCluster.find(op); }
 
+  // Split the cluster containing op into two clusters, one containing all
+  // operations before the op and one containing op and all operations after the
+  // op. Return the cluster containing op and all operations after the op.
+  Cluster splitClusterBefore(Operation *op, scf::ForOp forOp);
+
+  // Check if op a will show up before op b in the final unrolled code.
+  bool isOpBefore(Operation *a, Operation *b);
+
+  // Check if op a is in an earlier cluster than op b.
+  bool isOpInEarlierCluster(Operation *a, Operation *b);
+
+  // Check if op a is in the same cluster as op b.
+  bool isOpInSameCluster(Operation *a, Operation *b);
+
   SmallVector<std::tuple<Operation *, int, Cluster>>
   getOpsInOrder(scf::ForOp forOp);
   std::vector<std::pair<Operation *, unsigned>>
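These predicates and `splitClusterBefore` give the MMA pipelining rewriter a way to reason about relative op placement when deciding where a `wait_barrier` can live. A hedged usage fragment; `schedule`, `forOp`, `mma`, and `tmemLoad` are assumed to be in scope and are not part of this diff:

```cpp
// Illustrative fragment only: if the MMA is emitted before the tmem_load but
// they share a cluster, split that cluster so a wait inserted at the head of
// the new cluster runs right before the load.
if (schedule.isOpBefore(mma, tmemLoad) &&
    schedule.isOpInSameCluster(mma, tmemLoad)) {
  CoarseSchedule::Cluster tail = schedule.splitClusterBefore(tmemLoad, forOp);
  // `tail` now contains tmemLoad and everything scheduled after it.
  (void)tail;
}
```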

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 2 deletions

@@ -43,8 +43,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "NVPTX_ENABLE_DUMP",
     "STORE_TMEM_TO_GLOBAL_BYPASS_SMEM",
     "ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
-    "TRITON_F32_DEFAULT",
-    "ENABLE_MMA_V5_ATT_PIPELINE"
+    "TRITON_F32_DEFAULT"
     // clang-format on
 };
 

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 35 additions & 16 deletions

@@ -260,37 +260,56 @@ class AssignMMALatencies {
       : forOp(forOp), opLatency(opLatency) {};
 
   void run() {
-    if (!triton::tools::getBoolEnv("ENABLE_MMA_V5_ATT_PIPELINE")) {
-      int mmav5Count = 0;
-      for (auto &op : forOp.getBody()->without_terminator()) {
-        if (isa<ttng::MMAv5OpInterface>(&op)) {
-          mmav5Count++;
-        }
-      }
-      if (mmav5Count > 1)
-        return;
-    }
+    DenseMap<Operation *, int> mmaSelfLatency;
     // Check if the load op (mma operand) is pipelineable.
-    auto isLoadPipelineable = [&](Operation *op) {
+    auto isLoadToBePipelined = [&](Operation *op) {
       return opLatency.count(op) && opLatency[op] > 0;
     };
     for (auto &op : forOp.getBody()->without_terminator()) {
       // If the acc can not be multibuffered, do not pipeline the uses of
       // the MMA to later stages.
       if (auto mma = dyn_cast<ttng::MMAv5OpInterface>(&op)) {
-        if (ttng::mmaHasPipelineableOperands(mma, forOp, isLoadPipelineable) &&
-            !ttng::hasAccReadModifyWrite(mma, forOp) &&
-            ttng::isAccMultibufferingPossible(mma, forOp) &&
-            !getDisallowAccMultiBuffer(forOp)) {
-          opLatency[&op] = 1;
+        // Try to push out the wait by one stage even if the operands are not
+        // pipelineable, as long as we know where the loads are scheduled, so
+        // we can place the wait right before the loads.
+
+        if (hasSyncDots(forOp)) {
+          // Skip pipelining MMAs in loops where sync dots are used. This is a
+          // dirty heuristic for performance drops in kernels where we would
+          // rather have the last iteration peeled instead of a full iteration
+          // of masked operations just to execute a single wait.
+          continue;
+        }
+        auto pipeHelper = ttng::MMAv5PipelineableOperandsHelper(
+            mma, forOp, isLoadToBePipelined);
+        if (pipeHelper.isPipelineable ||
+            (pipeHelper.isOperandsStateDetermined &&
+             !ttng::hasLoadsAfterMMA(mma, forOp))) {
+          // MMA can be overlapped with itself.
+          mmaSelfLatency[mma] = 1;
+          if (!ttng::requiresAccMultiBuffering(mma, forOp) ||
+              (ttng::isAccMultibufferingPossible(mma, forOp) &&
+               !getDisallowAccMultiBuffer(forOp))) {
+            // The MMA's users can be pushed to the next stage.
+            opLatency[&op] = 1;
+          }
         }
       }
     }
+    serializeSelfLatencies(forOp->getParentOfType<ModuleOp>(), mmaSelfLatency);
   }
 
 private:
   scf::ForOp forOp;
   DenseMap<Operation *, int> &opLatency;
+
+  bool hasSyncDots(scf::ForOp forOp) {
+    for (auto &op : forOp.getBody()->without_terminator()) {
+      if (isa<mlir::triton::DotOp>(op))
+        return true;
+    }
+    return false;
+  }
 };
 
 } // namespace
