
Commit 21fd9eb

[PIPELINER] Pipeline RS WGMMA (#6804)
This PR allows pipelining WGMMAs that take the LHS in registers. The strategy is to wait for the WGMMA from the previous loop iteration to finish before executing the next one, so that its registers are not overwritten too early. Note that this depends on ptxas handling the register allocation correctly. On an 8k x 8k x 8k dense matmul, runtime improves from 2.441 to 2.039 (roughly a 1.2x speedup). We might still need to split the pointwise computations and interleave them with the WGMMAs, similar to how CUTLASS does it, but we don't do that in this PR. This pass supersedes WGMMAPrefetch, as it drops most of that pass's preconditions.
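To make the wait scheme concrete (an illustrative reading of the changes below, not text from the PR): if the RS WGMMA in the loop body is split into nSplit K-slices, each slice is followed by a warp_group_dot_wait with pendings = nSplit - 1, so up to nSplit - 1 WGMMAs stay in flight and a slice only issues once the matching slice from the previous iteration has retired and its registers are safe to reuse.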
1 parent 2ec711b commit 21fd9eb

File tree

11 files changed: +389 -1129 lines


include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 0 additions & 17 deletions
@@ -219,23 +219,6 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
                             "mlir::arith::ArithDialect"];
 }
 
-def TritonGPUWGMMAPrefetch : Pass<"tritongpu-wgmma-prefetch", "mlir::ModuleOp"> {
-  let summary = "prefetch for wgmma mixed precision";
-
-  let description = [{
-    This pass attempts to prefetch from shared memory for mixed-precision
-    wgmma when operand A is in the shared memory and needs to be loaded
-    to the local registers.
-  }];
-
-  let dependentDialects = [ "mlir::triton::gpu::TritonGPUDialect",
-                            "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
-                            "mlir::scf::SCFDialect",
-                            "mlir::arith::ArithDialect"];
-}
-
-
 def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
   let summary = "accelerate matmul";

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions
@@ -54,6 +54,10 @@ getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
 // Returns whether the op is a "view op", i.e. doesn't move any data
 bool isView(Operation *op);
 
+// Returns whether the op is a "noop op", i.e. has one input and one output
+// and lowers to llvm as the identity function (returns the input)
+bool isNoop(Operation *op);
+
 /* Dump Triton IR in graphviz dot format.
  *
  * You can override `onValue` and `onOperation` in a subclass to mark

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ add_triton_library(TritonGPUTransforms
   Pipeliner/PipeliningUtility.cpp
   Pipeliner/Schedule.cpp
   Prefetch.cpp
-  WGMMAPrefetch.cpp
   RemoveLayoutConversions.cpp
   ReorderInstructions.cpp
   CoalesceAsyncCopy.cpp

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp

Lines changed: 223 additions & 21 deletions
@@ -14,6 +14,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
+#include "triton/Tools/LinearLayout.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -30,6 +31,30 @@ namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
 namespace ttng = mlir::triton::nvidia_gpu;
 
+// Returns whether the dot is such that:
+// 1. The LHS comes from registers, and
+//   1.1. The LHS is defined inside the loop, and
+//   1.2. The LHS does not come from another dot.
+// For these dots, we assume that we cannot rewrite their
+// operands until the previous dot has finished.
+static bool isRSDotFromSIMD(Operation *dot, scf::ForOp forOp) {
+  auto dotOp = dyn_cast<ttng::WarpGroupDotOp>(dot);
+  if (!dotOp)
+    return false;
+  auto a = dotOp.getA();
+  if (!isa<RankedTensorType>(a.getType())) {
+    return false;
+  }
+  if (forOp.isDefinedOutsideOfLoop(a)) {
+    return false;
+  }
+  if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(a.getDefiningOp())) {
+    return !isa<ttg::NvidiaMmaEncodingAttr>(
+        cvt.getSrc().getType().getEncoding());
+  }
+  return true;
+}
+
 /// Find the minimum number of async_commit_group ops between the wait
 /// and the associated async_commit_group. This can be safely used as the wait
 /// number.
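For example, an LHS that is produced inside the loop by a load plus a ttg.convert_layout from a non-MMA layout qualifies, while a chained WGMMA, whose LHS is a ttg.convert_layout of an NvidiaMma-encoded result of a previous dot, is rejected by the final check.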
@@ -206,6 +231,148 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait,
   wait->erase();
 }
 
+// Split the LHS of a RSWGMMADot operation into multiple
+// tensors of size MxnewK via SplitOps
+SmallVector<Value> splitLhs(OpBuilder &builder,
+                            TypedValue<RankedTensorType> lhs, int64_t newK) {
+  auto loc = lhs.getLoc();
+  auto type = lhs.getType();
+  auto rank = type.getRank();
+  auto shape = to_vector(type.getShape());
+  auto nSplits = shape.back() / newK;
+  assert(nSplits > 1);
+  // Reshape K into 2x..x2xnewK
+  shape.pop_back();
+  for (int i = 1; i < nSplits; i *= 2) {
+    shape.push_back(2);
+  }
+  shape.push_back(newK);
+  lhs = builder.create<tt::ReshapeOp>(loc, shape, lhs);
+  // We want to split first the slowest running dim, then the second slowest,
+  // etc.
+  auto transOrder = to_vector(llvm::seq<int>(rank - 1));
+  transOrder.push_back(shape.size() - 1);
+  llvm::append_range(transOrder, llvm::reverse(llvm::seq(
+                                     rank - 1, (int64_t)shape.size() - 1)));
+  lhs = builder.create<tt::TransOp>(loc, lhs, transOrder);
+  // We split recursively
+  SmallVector<Value> curr;
+  SmallVector<Value> ret = {lhs};
+  for (int i = 1; i < nSplits; i *= 2) {
+    curr = ret;
+    ret.clear();
+    for (auto v : curr) {
+      auto split = builder.create<tt::SplitOp>(loc, v);
+      ret.push_back(split.getResult(0));
+      ret.push_back(split.getResult(1));
+    }
+  }
+
+  auto mmav3Type =
+      type.clone(cast<RankedTensorType>(ret.front().getType()).getShape());
+  // Convert the LHS to mmav3 layout
+  for (auto &v : ret) {
+    v = builder.create<ttg::ConvertLayoutOp>(loc, mmav3Type, v);
+    // The layouts are noops by construction
+    assert(minimalCvtLayout(v.getType(), mmav3Type) ==
+           tt::LinearLayout::empty());
+  }
+  assert(ret.size() == nSplits);
+  return ret;
+}
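A minimal standalone sketch of the shape bookkeeping above, assuming a hypothetical 128x64 LHS and newK = 16 (plain C++ mirroring the logic, not code from the pass):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      const int64_t M = 128, K = 64, newK = 16; // assumed example sizes
      std::vector<int64_t> shape = {M, K};
      const int rank = (int)shape.size();
      const int64_t nSplits = K / newK; // 4
      // Reshape K into 2x..x2xnewK: [128, 64] -> [128, 2, 2, 16]
      shape.pop_back();
      for (int64_t i = 1; i < nSplits; i *= 2)
        shape.push_back(2);
      shape.push_back(newK);
      // Transpose order: leading dims, then newK, then the 2s from
      // innermost to outermost; here that is [0, 3, 2, 1], giving a
      // [128, 16, 2, 2] tensor.
      std::vector<int> order;
      for (int i = 0; i < rank - 1; ++i)
        order.push_back(i);
      order.push_back((int)shape.size() - 1);
      for (int i = (int)shape.size() - 2; i >= rank - 1; --i)
        order.push_back(i);
      for (int d : order)
        std::printf("%d ", d); // prints: 0 3 2 1
      // Two rounds of SplitOp on the trailing 2s then yield four
      // [128, 16] K-slices.
      std::printf("-> %lld slices of %lldx%lld\n", (long long)nSplits,
                  (long long)M, (long long)newK);
    }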
+
+// Split the RHS of a RSWGMMADot operation into multiple
+// tensors of size newKxN via MemDescSubview
+SmallVector<Value> splitRhs(OpBuilder &builder,
+                            TypedValue<ttg::MemDescType> rhs, int64_t newK) {
+  auto loc = rhs.getLoc();
+  auto type = rhs.getType();
+  auto rank = type.getRank();
+  auto kDim = rank - 2;
+  auto nSplits = type.getShape()[kDim] / newK;
+  auto shape = llvm::to_vector(type.getShape());
+  shape[kDim] = newK;
+  SmallVector<Value> offsetsVal;
+  for (int i = 0; i < rank; i++) {
+    offsetsVal.push_back(builder.create<arith::ConstantIntOp>(loc, 0, 32));
+  }
+  auto newType = ttg::MemDescType::get(
+      shape, type.getElementType(), type.getEncoding(), type.getMemorySpace(),
+      /*isMutable=*/false, type.getAllocShape());
+  SmallVector<Value> ret;
+  for (int i = 0; i < nSplits; i++) {
+    offsetsVal[kDim] = builder.create<arith::ConstantIntOp>(loc, i * newK, 32);
+    Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
+        loc, newType, rhs, offsetsVal);
+    ret.push_back(newSmem);
+  }
+  return ret;
+}
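For instance, a 64x128 shared-memory RHS with newK = 16 becomes four 16x128 subviews taken at K offsets 0, 16, 32 and 48; no data is copied, each split is just a descriptor into the same buffer.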
+
+std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
+  // Splits a wgmma(tensor, shmem) MxK, KxN -> MxN along K into multiple
+  // wgmma(tensor, shmem) Mx16, 16xN -> MxN, where 16 is the instruction size
+  if (!isa<RankedTensorType>(dotOp.getA().getType())) {
+    return {dotOp};
+  }
+
+  auto a = cast<TypedValue<RankedTensorType>>(dotOp.getA());
+  auto b = cast<TypedValue<ttg::MemDescType>>(dotOp.getB());
+  auto origK = a.getType().getShape().back();
+  auto newK = cast<ttg::NvidiaMmaEncodingAttr>(dotOp.getType().getEncoding())
+                  .getInstrShape()[2];
+  auto numSplits = origK / newK;
+  // Nothing to split
+  if (numSplits <= 1) {
+    return {dotOp};
+  }
+
+  assert(origK % newK == 0 && "origK must be divisible by newK");
+  auto builder = OpBuilder(dotOp);
+  auto loc = dotOp.getLoc();
+  auto lhss = splitLhs(builder, a, newK);
+  auto rhss = splitRhs(builder, b, newK);
+  assert(lhss.size() == numSplits && "lhs must have the same number of splits");
+  assert(rhss.size() == numSplits && "rhs must have the same number of splits");
+
+  Value useC = dotOp.getUseC();
+  Value C = dotOp.getC();
+  auto numImpreciseAccLeft = dotOp.getMaxNumImpreciseAcc();
+  std::vector<ttng::WarpGroupDotOp> dots;
+  for (int i = 0; i < numSplits; i++) {
+    // 2**30 is to prevent the subtile from adding
+    // extra imprecise accumulators, see WGMMA.cpp
+    uint32_t numImpreciseAcc = (numImpreciseAccLeft > newK)
+                                   ? 1073741824 // 2**30
+                                   : numImpreciseAccLeft;
+    // Deduct the actual consumed imprecise acc
+    numImpreciseAccLeft -= std::min(numImpreciseAccLeft, newK);
+    auto dot = builder.create<ttng::WarpGroupDotOp>(
+        loc, dotOp.getType(), lhss[i], rhss[i], C, useC,
+        dotOp.getInputPrecision(), numImpreciseAcc, dotOp.getIsAsync());
+    dots.push_back(dot);
+    C = dot.getResult();
+    useC = builder.create<mlir::arith::ConstantIntOp>(loc, 1, 1);
+  }
+  dotOp.replaceAllUsesWith(dots.back().getResult());
+  dotOp.erase();
+  return dots;
+}
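A minimal sketch of the imprecise-accumulator bookkeeping in the loop above, assuming a hypothetical maxNumImpreciseAcc of 32 with newK = 16 and four splits:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t newK = 16;
      uint32_t numImpreciseAccLeft = 32; // assumed getMaxNumImpreciseAcc()
      for (int i = 0; i < 4; ++i) {
        // 2**30 keeps a subtile from adding extra imprecise accumulation.
        const uint32_t numImpreciseAcc =
            (numImpreciseAccLeft > newK) ? (1u << 30) : numImpreciseAccLeft;
        numImpreciseAccLeft -= std::min(numImpreciseAccLeft, newK);
        std::printf("split %d: numImpreciseAcc = %u\n", i, numImpreciseAcc);
      }
      // Prints 1073741824 (the sentinel), then 16, then 0, then 0.
    }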
+
+// Apply splitRSDot to all dots in the input list.
+llvm::MapVector<Operation *, int>
+splitRSDots(const llvm::MapVector<Operation *, int> &dots) {
+  llvm::MapVector<Operation *, int> ret;
+  for (auto [dot, iterArgIdx] : dots) {
+    auto newDots = splitRSDot(cast<ttng::WarpGroupDotOp>(dot));
+    for (auto newDot : newDots) {
+      ret.insert({newDot, iterArgIdx});
+    }
+  }
+  return ret;
+}
+
 // Determines whether a given MMAv3 dot op, represented as ttng.warp_group_dot,
 // needs a wait immediately after it.
 //
@@ -260,21 +427,11 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
                                                 scf::ForOp forOp) {
   LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp);
 
-  // Rule 1: All shmem operands are multi-buffered.
   auto checkOperand = [&](Value operand) {
-    if (!isa<ttg::SharedEncodingTrait>(
-            cast<ttg::TensorOrMemDesc>(operand.getType()).getEncoding())) {
-      // Rule 1a: Register operands must not be modified within the loop.
-      // First, check for chained WGMMA as an exception.
-      if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(operand.getDefiningOp())) {
-        return isa<ttg::NvidiaMmaEncodingAttr>(
-            cvt.getSrc().getType().getEncoding());
-      }
-      // And then, do a stricter-than-necessary check for now, that the operand
-      // is defined outside the loop.
-      return forOp.isDefinedOutsideOfLoop(operand);
+    // We can always make RSGEMM async as long as the RHS can be multi-buffered
+    if (isa<RankedTensorType>(operand.getType())) {
+      return true;
     }
-
     // If it's a shmem operand, it must either be defined outside the loop, or
     // come from an MemDescSubview op. Only ConvertLayout and view ops are
     // allowed in between.
@@ -296,6 +453,7 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
            transitiveOperand.getDefiningOp<ttg::MemDescSubviewOp>();
   };
 
+  // Rule 1: All shmem operands are multi-buffered.
   // We don't have to call checkOperand on getC() because it's always in
   // registers, never in shmem.
   assert(isa<ttg::NvidiaMmaEncodingAttr>(dotOp.getC().getType().getEncoding()));
@@ -315,6 +473,13 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
   while (!queue.empty()) {
     auto [user, argIdx] = queue.pop_back_val();
     if (user->getParentOp() == forOp) {
+      // We support noops in between the dot and the yield
+      if (isNoop(user)) {
+        for (auto &use : user->getResult(0).getUses()) {
+          queue.push_back({use.getOwner(), use.getOperandNumber()});
+        }
+        continue;
+      }
       if (isa<scf::YieldOp>(user)) {
         if (iterArg) {
           // The dot is used by the loop's yield, but we can't have any other
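For example, an accumulator that flows through a tt.reshape or a trivial ttg.convert_layout on its way to the scf.yield is now followed through that op instead of disqualifying the dot.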
@@ -343,15 +508,28 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
       return std::nullopt;
     }
   }
+  // Rule 2.1: We don't make the dot async if the accumulator is not fp32.
+  if (!dotOp.getC().getType().getElementType().isF32()) {
+    LDBG("Can't make dot async because the accumulator is not fp32");
+    return std::nullopt;
+  }
 
-  // Rule 3a: Are the only users of the dot's result from iteration i-1 other
-  // MMAv3 dots? If so, we're done, this dot can be properly async.
-  if (llvm::all_of(iterArg.getUses(), [&](OpOperand &use) {
-        return isa<ttng::WarpGroupDotOp>(use.getOwner()) &&
-               use.getOperandNumber() == 2;
-      })) {
+  // Rule 3a: Check that every use of the dot's result (iterArg) eventually
+  // reaches a WarpGroupDotOp (with use index 2), possibly after passing
+  // through a chain of noops.
+  std::function<bool(OpOperand &)> isTransitivelyWarpGroupDot =
+      [&](OpOperand &use) -> bool {
+    Operation *user = use.getOwner();
+    if (isa<ttng::WarpGroupDotOp>(user))
+      return use.getOperandNumber() == 2;
+    if (isNoop(user))
+      return llvm::all_of(user->getResult(0).getUses(),
+                          isTransitivelyWarpGroupDot);
+    return false;
+  };
+
+  if (llvm::all_of(iterArg.getUses(), isTransitivelyWarpGroupDot))
     return iterArgIdx;
-  }
 
   // Rule 3b: Are all users of the dot's result from iteration i-1 after the
   // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be
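So a chain such as warp_group_dot -> tt.trans -> warp_group_dot (feeding operand 2, the accumulator) is accepted, while any use that escapes through a non-noop, non-dot op falls through to Rule 3b.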
@@ -414,7 +592,21 @@ static void insertAsyncWarpGroupDotWaitInLoop(
 
   // Insert waits before the users of the properly async dots other than loop
   // yield.
-  for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) {
+  for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) {
+    // If the dot takes the LHS on registers, we add a wait for the number
+    // of properly async dots in the loop minus one.
+    // This makes sure that the dot will wait until itself from the previous
+    // iteration has completed, so as to avoid rewriting the registers.
+    if (isRSDotFromSIMD(asyncDot, forOp)) {
+      OpBuilder builder(asyncDot);
+      builder.setInsertionPointAfter(asyncDot);
+      auto newWait = builder.create<ttng::WarpGroupDotWaitOp>(
+          asyncDot->getLoc(), ArrayRef<Value>{}, properlyAsyncDots.size() - 1);
+      SmallVector<Value> waitOperands = {asyncDot->getResult(0)};
+      threadValuesThroughWait(newWait, waitOperands);
+      continue;
+    }
+
     SmallVector<OpOperand *> uses;
     for (auto &use : asyncDot->getUses()) {
       if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner())) {
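For example, if splitRSDots produced four properly async dots, each RS dot is followed by a wait with pendings = 3, so by the time a given K-slice issues in iteration i, the same slice from iteration i-1 has retired and its LHS registers can be rewritten.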
@@ -448,6 +640,11 @@
   // by a dot.)
   IRRewriter builder(forOp.getContext());
   auto lastAsyncDot = properlyAsyncDots.back().first;
+  // If the last dot is an RS dot, we don't need to insert a wait
+  // as we have already inserted a wait(properlyAsyncDots.size() - 1)
+  if (isRSDotFromSIMD(lastAsyncDot, forOp)) {
+    return;
+  }
   builder.setInsertionPointAfter(lastAsyncDot);
   auto wait = builder.create<ttng::WarpGroupDotWaitOp>(
       lastAsyncDot->getLoc(),
@@ -504,6 +701,11 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
     return;
   }
 
+  // Split RS dots into dots with K = 16 (the instruction size of MMAv3).
+  // If we split them into nSplit dots, we will be able to keep nSplit-1 dots
+  // in flight at a time.
+  properlyAsyncDots = splitRSDots(properlyAsyncDots);
+
   // Next, insert a wait inside the loop. We pipeline to depth 2, so the third
   // iteration's set of asynchronous dots (and their corresponding async copies
   // from global to shmem) can't start until the first iteration's set has

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 11 additions & 0 deletions
@@ -149,6 +149,17 @@ bool isView(Operation *op) {
   return isa<ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp>(op);
 }
 
+bool isNoop(Operation *op) {
+  if (isa<ReshapeOp, TransOp>(op))
+    return true;
+  if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(op)) {
+    // The conversion op is a noop if the conversion layout is trivial
+    return minimalCvtLayout(cvt.getSrc().getType(),
+                            cvt.getResult().getType()) == LinearLayout::empty();
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // GraphDumper
 //===----------------------------------------------------------------------===//
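As examples of this classification: tt.reshape and tt.trans always qualify, while a ttg.convert_layout qualifies only when minimalCvtLayout reduces to the empty LinearLayout, i.e. when the source and destination layouts place every element in the same register of the same thread.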
