
Commit b489f34

Mogball authored and FindHao committed
[Blackwell] Optimize MMA warp specialization to allow multiple consumers of MMAv5 result (triton-lang#6487)
This is a redo of triton-lang#6457, but without affecting existing kernels. It adds a list of (barrier, pred) pairs to the MMAv5 ops, but does not alter codegen for the current MMAv5 ops -- they still keep a single barrier argument. This enables warp specialization to handle multiple consumers of MMA partition results without a separate "waiter" partition, which reduces the signalling latency between the MMA and load partitions and yields a small but consistent performance increase of up to 2.2% in dense fp8 matmul. Importantly, it also reduces the number of required warps by 1 and simplifies the codegen for warp specialization, which will be important for FMHA.
1 parent 0a70a66 commit b489f34

28 files changed: +574 -454 lines
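
To make the mechanism concrete, here is a minimal sketch of the new builder surface declared in the .td diff below (this is not code from the commit; loadBarrier, epilogueBarrier, loadPred, and epiloguePred are invented placeholders):

    // Hedged sketch: create an asynchronous MMAv5 op that signals two
    // consumer partitions directly, one (barrier, pred) pair per consumer.
    auto mma = rewriter.create<triton::nvidia_gpu::TCGen5MMAOp>(
        loc, a, b, acc, /*useD=*/vTrue, /*pred=*/vTrue, /*two_ctas=*/false,
        /*barriers=*/ValueRange{loadBarrier, epilogueBarrier},
        /*barrier_preds=*/ValueRange{loadPred, epiloguePred});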

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 33 additions & 0 deletions
@@ -8,6 +8,7 @@
 #include <vector>

 namespace mlir {
+class DominanceInfo;
 class ImplicitLocOpBuilder;
 namespace triton {

@@ -20,6 +21,38 @@ static const char *kLoopClusterAttrName = "loop.cluster";
 static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
 static const char *kLatencyAttrName = "tt.latency";

+//===----------------------------------------------------------------------===//
+// Hoisting Utilities
+//===----------------------------------------------------------------------===//
+
+// By default, an operation can be hoisted if it is pure scalar operation.
+bool isPureScalarOp(Operation *op);
+
+// Given a set of values and a reference operation, return true if all of the
+// values dominate the reference operation OR a set of "trivial" operations can
+// be moved before the reference operation such that the value set dominates the
+// reference operation.
+//
+// Returns false if it is not possible to make the values dominate the reference
+// operation. The function determines "trivial"-ness with the given callback.
+// By default, it determines that memory-effect-free and scalar operations are
+// trivial.
+bool getDominatingValueSetOpsToHoist(
+    DominanceInfo &domInfo, Operation *refOp, ArrayRef<Value> valueSet,
+    llvm::SetVector<Operation *> &toHoist,
+    function_ref<bool(Operation *)> canHoist = isPureScalarOp);
+
+// Hoist the given set of operations above the reference operation.
+void hoistOpsBefore(Operation *refOp,
+                    const llvm::SetVector<Operation *> &toHoist);
+// Hoist the given set of operations before the iterator.
+void hoistOpsBefore(Block *block, Block::iterator it,
+                    const llvm::SetVector<Operation *> &toHoist);
+
+//===----------------------------------------------------------------------===//
+// Loop Pipelining Utilities
+//===----------------------------------------------------------------------===//
+
 bool loopHasDistGreaterThanOne(scf::ForOp forOp);
 bool isOuterLoop(scf::ForOp forOp);
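
The intended usage pattern of these two utilities is presumably query-then-hoist: collect the ops first, and move them only if the whole value set turns out to be hoistable. A sketch under that assumption (outerLoop and bounds are placeholder names, not identifiers from this commit):

    llvm::SetVector<Operation *> toHoist;
    // Collect the trivial defining ops that would have to move for `bounds`
    // to dominate `outerLoop`; returns false without modifying the IR.
    if (getDominatingValueSetOpsToHoist(domInfo, outerLoop, bounds, toHoist))
      hoistOpsBefore(outerLoop, toHoist); // move them, defs before uses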

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 3 additions & 3 deletions
@@ -19,10 +19,10 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
       "void",
       "setUseAccumulator",
       (ins "::mlir::Value":$flag)>,
-    InterfaceMethod<"Associate a new barrier to this MMAv5 op.",
+    InterfaceMethod<"Associate a new completion barrier to this MMAv5 op.",
       "void",
-      "setBarrier",
-      (ins "::mlir::Value":$barrier)>,
+      "addCompletionBarrier",
+      (ins "::mlir::Value":$barrier, "::mlir::Value":$pred)>,
     InterfaceMethod<"Return the accumulator.",
       "::mlir::Value",
       "getAccumulator">,

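The rename from setBarrier to addCompletionBarrier reflects the semantic shift: an MMAv5 op no longer holds at most one barrier but accumulates a list of them, each paired with a predicate that gates whether the op arrives on that barrier.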
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 89 additions & 42 deletions
@@ -389,55 +389,102 @@ def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
   let assemblyFormat = "attr-dict";
 }

-def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
-  let summary = "block level op mapping to tensorcore gen5 mma";
+def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DeclareOpInterfaceMethods<DotOpInterface>,
+  DeclareOpInterfaceMethods<MMAv5OpInterface>,
+  SameVariadicOperandSize
+]> {
+  let summary = "block level op mapping to tensorcore gen5 mma";

-  let description = [{
-    $d += matrix_multiply($a, $b).
-    If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
-    If there is a barrier the result will be safe to read after a barrier wait.
-    If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs.
-    and syncronize both CTAs if the op is synchronous.
-  }];
+  let description = [{
+    $d += matrix_multiply($a, $b).
+    If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
+    If there is a barrier the result will be safe to read after a barrier wait.
+    If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs.
+    and syncronize both CTAs if the op is synchronous.
+  }];

-  let arguments = (ins TTG_MemDescType:$a,
-                       TTG_MemDescType:$b,
-                       TTG_MemDescType:$d,
-                       I1:$useD,
-                       I1:$pred,
-                       Optional<TTG_MemDescType>:$barrier,
-                       OptionalAttr<UnitAttr>:$two_ctas);
+  let arguments = (ins
+    TTG_MemDescType:$a,
+    TTG_MemDescType:$b,
+    TTG_MemDescType:$d,
+    I1:$useD,
+    I1:$pred,
+    Variadic<TTG_MemDescType>:$barriers,
+    Variadic<I1>:$barrier_preds,
+    OptionalAttr<UnitAttr>:$two_ctas
+  );

-  // TODO: improve printing format.
-  let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
+  let builders = [
+    OpBuilder<(ins
+      "Value":$a, "Value":$b, "Value":$d, "Value":$useD, "Value":$pred,
+      CArg<"bool", "false">:$two_ctas, CArg<"ValueRange", "{}">:$barriers,
+      CArg<"ValueRange", "{}">:$barrier_preds)>
+  ];
+
+  let assemblyFormat = [{
+    $a`,` $b`,` $d`,` $useD`,` $pred
+    `` custom<BarriersAndPreds>($barriers, $barrier_preds)
+    attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
+    qualified(type($d)) (`,` qualified(type($barriers))^)?
+  }];
 }

-def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
-  let summary = "block level op mapping to tensorcore gen5 mma";
+def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>,
+  DeclareOpInterfaceMethods<MMAv5OpInterface>,
+  SameVariadicOperandSize
+]> {
+  let summary = "block level op mapping to tensorcore gen5 mma";

-  let description = [{
-    $d += matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale))
-    If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
-    If there is a barrier the result will be safe to read after a barrier wait.
-  }];
+  let description = [{
+    $d += matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale))
+    If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
+    If there is a barrier the result will be safe to read after a barrier wait.
+  }];

-  let arguments = (ins TTG_MemDescType:$a,
-                       TTG_MemDescType:$b,
-                       TTG_MemDescType:$d,
-                       TTG_MemDescType:$a_scale,
-                       TTG_MemDescType:$b_scale,
-                       TT_ScaleDotElemTypeAttr:$a_type,
-                       TT_ScaleDotElemTypeAttr:$b_type,
-                       I1:$useD,
-                       I1:$pred,
-                       Optional<TTG_MemDescType>:$barrier);
-  let extraClassDeclaration = [{
-    int64_t getBlockM();
-    int64_t getBlockN();
-    int64_t getBlockK();
-  }];
-  // TODO: improve printing format.
-  let assemblyFormat = "$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
+  let arguments = (ins
+    TTG_MemDescType:$a,
+    TTG_MemDescType:$b,
+    TTG_MemDescType:$d,
+    TTG_MemDescType:$a_scale,
+    TTG_MemDescType:$b_scale,
+    TT_ScaleDotElemTypeAttr:$a_type,
+    TT_ScaleDotElemTypeAttr:$b_type,
+    I1:$useD,
+    I1:$pred,
+    Variadic<TTG_MemDescType>:$barriers,
+    Variadic<I1>:$barrier_preds
+  );
+  let extraClassDeclaration = [{
+    int64_t getBlockM();
+    int64_t getBlockN();
+    int64_t getBlockK();
+  }];
+
+  let builders = [
+    // Namespaces need to be prefixed so ODS prefers our
+    // custom builder signature over the default-generated one.
+    OpBuilder<(ins
+      "::mlir::Value":$a, "::mlir::Value":$b, "::mlir::Value":$d,
+      "::mlir::Value":$a_scale, "::mlir::Value":$b_scale,
+      "::mlir::triton::ScaleDotElemType":$a_type,
+      "::mlir::triton::ScaleDotElemType":$b_type,
+      "::mlir::Value":$useD, "::mlir::Value":$pred,
+      CArg<"::mlir::ValueRange", "{}">:$barriers,
+      CArg<"::mlir::ValueRange", "{}">:$barrier_preds)>
+  ];
+
+  let assemblyFormat = [{
+    $a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred
+    `lhs` `=` $a_type `rhs` `=` $b_type
+    `` custom<BarriersAndPreds>($barriers, $barrier_preds)
+    attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
+    qualified(type($d)) `,` qualified(type($a_scale)) `,`
+    qualified(type($b_scale)) (`,` qualified(type($barriers))^)?
+  }];
 }

 def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load"> {
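
The custom<BarriersAndPreds> directive requires matching parse/print hooks in the dialect's C++ sources, which are not part of this excerpt. By MLIR's ODS conventions they would look roughly like the following; treat these signatures as an assumption, not the commit's actual code:

    // Assumed hook signatures for the custom<BarriersAndPreds> directive;
    // ODS emits calls to parseBarriersAndPreds/printBarriersAndPreds.
    static mlir::ParseResult parseBarriersAndPreds(
        mlir::OpAsmParser &parser,
        llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &barriers,
        llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &preds);
    static void printBarriersAndPreds(mlir::OpAsmPrinter &p,
                                      mlir::Operation *op,
                                      mlir::OperandRange barriers,
                                      mlir::OperandRange preds);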

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 2 additions & 2 deletions
@@ -548,7 +548,7 @@ class BlockedToMMAv5 : public mlir::OpRewritePattern<DotOp> {
         loc, accMemDescType, cvtAcc);
     auto vTrue = rewriter.create<arith::ConstantIntOp>(dotOp.getLoc(), 1, 1);
     auto mma = rewriter.create<triton::nvidia_gpu::TCGen5MMAOp>(
-        loc, a, b, acc, vTrue, vTrue, Value(), UnitAttr());
+        loc, a, b, acc, /*useD=*/vTrue, /*pred=*/vTrue);
     mma.setTwoCtas(useTwoCTAs);

     auto ld =
@@ -735,7 +735,7 @@ class ScaledBlockedToMMAv5
     auto vTrue = rewriter.create<arith::ConstantIntOp>(dotOp.getLoc(), 1, 1);
     rewriter.create<triton::nvidia_gpu::TCGen5MMAScaledOp>(
         loc, a, b, acc, scaleA, scaleB, dotOp.getAElemType(),
-        dotOp.getBElemType(), vTrue, vTrue, Value());
+        dotOp.getBElemType(), /*useD=*/vTrue, /*pred=*/vTrue);

     auto ld =
         rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(loc, newAccType, acc);

lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp

Lines changed: 7 additions & 55 deletions
@@ -1,4 +1,3 @@
-#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -238,7 +237,8 @@ static Logue createLogueFrom(llvm::iterator_range<Block::iterator> ops,
 // Only hoist operations that are side-effect free and "cheap" (i.e. only scalar
 // operands). Importantly, we need to be able to hoist code generated by fusing
 // children loops into their parents so the algorithm can be applied
-// recursively.
+// recursively. This includes integer division, which are not speculatable, but
+// we know they will never divide by zero.
 static bool canHoistLoopBoundComputation(Operation *op) {
   auto isScalar = [](Type type) { return type.isIntOrIndexOrFloat(); };
   return isMemoryEffectFree(op) &&
@@ -251,50 +251,8 @@ static bool canHoistLoopBoundComputation(Operation *op) {
 static bool isOuterLoopInvariant(mlir::DominanceInfo &domInfo, scf::ForOp outer,
                                  ArrayRef<Value> values,
                                  llvm::SetVector<Operation *> &toHoist) {
-  // The set of operations within `outer` that are being checked if they can be
-  // hoisted. This set prevents checking operations twice but also if the
-  // computation can be hoisted, this becomes the set of operations to hoist.
-  llvm::SetVector<Operation *> visited;
-
-  // Climb the use-def chain breadth-first so that operations can be hoisted in
-  // the reverse visitation order.
-  std::queue<Value> queue;
-  for (Value value : values)
-    queue.push(value);
-
-  while (!queue.empty()) {
-    Value value = queue.front();
-    queue.pop();
-
-    // If the value properly dominates the outer loop, then it must be invariant
-    // to it.
-    if (domInfo.properlyDominates(value, outer))
-      continue;
-    // If the value is a block argument, it cannot be hoisted.
-    if (auto arg = dyn_cast<BlockArgument>(value))
-      return false;
-
-    Operation *op = value.getDefiningOp();
-    // Check if the op was already visited.
-    if (visited.contains(op))
-      continue;
-    // If the defining op cannot be hoisted, then the value cannot be made loop
-    // invariant.
-    if (!canHoistLoopBoundComputation(op))
-      return false;
-    visited.insert(op);
-    // Recurse on the operands of the op.
-    for (Value operand : op->getOperands())
-      queue.push(operand);
-  }
-
-  // The operations in `visited` must be hoisted. Note that operations are not
-  // added to `toHoist` unless all of `values` can be hoisted. This is to avoid
-  // hoisting operations for loops that don't end up getting fused if one of
-  // their bounds operands cannot be hoisted.
-  toHoist.insert(visited.begin(), visited.end());
-
-  return true;
+  return getDominatingValueSetOpsToHoist(domInfo, outer, values, toHoist,
+                                         canHoistLoopBoundComputation);
 }

 // Pessimistically assume the internal storage bitwidth for index types.
@@ -545,9 +503,7 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
   // The transformation will definitely succeed on `childrenToFuse`. `toHoist`
   // only contains the operations that must be hoisted for `childrenToFuse` to
   // be fusible.
-  toHoist = topologicalSort(toHoist);
-  for (Operation *op : toHoist)
-    op->moveBefore(outer);
+  hoistOpsBefore(outer, toHoist);

   // Determine the integer type to use for the length computations. Use an
   // integer bitwidth twice the size of the largest integer, up to 64 bits, to
@@ -993,9 +949,7 @@ static void sinkOps(Region &limit, Block *sinkBlock, Block::iterator sinkBefore,
   if (sunkOps.empty())
     return;

-  sunkOps = topologicalSort(sunkOps);
-  for (Operation *op : sunkOps)
-    op->moveBefore(sinkBlock, sinkBefore);
+  hoistOpsBefore(sinkBlock, sinkBefore, sunkOps);
 }

 // Sink ops from the prologue into the epilogue when possible.
@@ -1028,9 +982,7 @@ static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop,
     return failure();

   // Hoist the inner loop bounds computations if necessary.
-  toHoist = topologicalSort(toHoist);
-  for (Operation *op : toHoist)
-    op->moveBefore(outerLoop);
+  hoistOpsBefore(outerLoop, toHoist);

   // Mark the inner loop.
   ImplicitLocOpBuilder b(loc, outerLoop);
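
Since hoistOpsBefore replaces the same inline topologicalSort-then-moveBefore pattern in three places, its definition (in PipeliningUtility.cpp, outside this diff) presumably amounts to the following sketch, inferred from the code it replaces rather than taken from the commit:

    // Assumed implementation of the shared utility.
    void hoistOpsBefore(Operation *refOp,
                        const llvm::SetVector<Operation *> &toHoist) {
      // Sort so that defs are moved before their uses, then hoist each op.
      for (Operation *op : mlir::topologicalSort(toHoist))
        op->moveBefore(refOp);
    }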

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 2 additions & 1 deletion
@@ -913,7 +913,8 @@ scf::ForOp createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
     barrierSlice =
         triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
   }
-  mma.setBarrier(barrierSlice);
+  mma.addCompletionBarrier(barrierSlice,
+                           builder.create<arith::ConstantIntOp>(loc, 1, 1));

   // List of buffers that may be used until wait completes
   SmallVector<Value> waitBuffers;
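
This pass still registers a single barrier with a constant-true predicate, matching the old setBarrier semantics. The new interface is what lets warp specialization fan the completion signal out to several consumers; a hypothetical continuation of the code above (peerBarrierSlice is an invented name for a second consumer's barrier):

    // With per-consumer barriers, a second consumer partition waits on its
    // own barrier instead of going through a dedicated "waiter" partition.
    Value vTrue = builder.create<arith::ConstantIntOp>(loc, 1, 1);
    mma.addCompletionBarrier(peerBarrierSlice, vTrue);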
