Skip to content

Commit 0bb6019

Browse files
authored
Revert "[PIPELINER] Pipeline RS WGMMA (#6804)" (#6810)
This reverts commit 21fd9eb as it breaks an internal test
1 parent 21fd9eb commit 0bb6019

File tree

11 files changed

+1129
-389
lines changed

11 files changed

+1129
-389
lines changed

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,23 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
219219
"mlir::arith::ArithDialect"];
220220
}
221221

222+
def TritonGPUWGMMAPrefetch : Pass<"tritongpu-wgmma-prefetch", "mlir::ModuleOp"> {
223+
let summary = "prefetch for wgmma mixed precision";
224+
225+
let description = [{
226+
This pass attempts to prefetch from shared memory for mixed-precision
227+
wgmma when operand A is in the shared memory and needs to be loaded
228+
to the local registers.
229+
}];
230+
231+
let dependentDialects = [ "mlir::triton::gpu::TritonGPUDialect",
232+
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
233+
"mlir::scf::SCFDialect",
234+
"mlir::arith::ArithDialect"];
235+
}
236+
237+
238+
222239
def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
223240
let summary = "accelerate matmul";
224241

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
5454
// Returns whether the op is a "view op", i.e. doesn't move any data
5555
bool isView(Operation *op);
5656

57-
// Returns whether the op is a "noop op", i.e. has one input and one output
58-
// and lowers to llvm as the identity function (returns the input)
59-
bool isNoop(Operation *op);
60-
6157
/* Dump Triton IR in graphviz dot format.
6258
*
6359
* You can override `onValue` and `onOperation` in a subclass to mark

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_triton_library(TritonGPUTransforms
2525
Pipeliner/PipeliningUtility.cpp
2626
Pipeliner/Schedule.cpp
2727
Prefetch.cpp
28+
WGMMAPrefetch.cpp
2829
RemoveLayoutConversions.cpp
2930
ReorderInstructions.cpp
3031
CoalesceAsyncCopy.cpp

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp

Lines changed: 21 additions & 223 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
1515
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1616
#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
17-
#include "triton/Tools/LinearLayout.h"
1817
#include "llvm/ADT/MapVector.h"
1918
#include "llvm/ADT/STLExtras.h"
2019
#include "llvm/ADT/SetVector.h"
@@ -31,30 +30,6 @@ namespace tt = mlir::triton;
3130
namespace ttg = mlir::triton::gpu;
3231
namespace ttng = mlir::triton::nvidia_gpu;
3332

34-
// Returns whether the dot is such that:
35-
// 1. The LHS comes from registers and
36-
// 1.1 The LHS is defined inside the loop
37-
// 1.2. The LHS does not come from another dot
38-
// For these dots, we assume that we cannot rewrite their
39-
// operands until the previous dot has finished
40-
static bool isRSDotFromSIMD(Operation *dot, scf::ForOp forOp) {
41-
auto dotOp = dyn_cast<ttng::WarpGroupDotOp>(dot);
42-
if (!dotOp)
43-
return false;
44-
auto a = dotOp.getA();
45-
if (!isa<RankedTensorType>(a.getType())) {
46-
return false;
47-
}
48-
if (forOp.isDefinedOutsideOfLoop(a)) {
49-
return false;
50-
}
51-
if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(a.getDefiningOp())) {
52-
return !isa<ttg::NvidiaMmaEncodingAttr>(
53-
cvt.getSrc().getType().getEncoding());
54-
}
55-
return true;
56-
}
57-
5833
/// Find the minimum number of async_commit_group ops between the wait
5934
/// and the associated async_commit_group. This can be safely used as the wait
6035
/// number.
@@ -231,148 +206,6 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait,
231206
wait->erase();
232207
}
233208

234-
// Split the LHS of a RSWGMMADot operation into multiple
235-
// tensors of size MxnewK via SplitOps
236-
SmallVector<Value> splitLhs(OpBuilder &builder,
237-
TypedValue<RankedTensorType> lhs, int64_t newK) {
238-
auto loc = lhs.getLoc();
239-
auto type = lhs.getType();
240-
auto rank = type.getRank();
241-
auto shape = to_vector(type.getShape());
242-
auto nSplits = shape.back() / newK;
243-
assert(nSplits > 1);
244-
// Reshape K == 2x..x2xnewK
245-
shape.pop_back();
246-
for (int i = 1; i < nSplits; i *= 2) {
247-
shape.push_back(2);
248-
}
249-
shape.push_back(newK);
250-
lhs = builder.create<tt::ReshapeOp>(loc, shape, lhs);
251-
// We want to split first the slowest running dim, then the second slowest,
252-
// etc.
253-
auto transOrder = to_vector(llvm::seq<int>(rank - 1));
254-
transOrder.push_back(shape.size() - 1);
255-
llvm::append_range(transOrder, llvm::reverse(llvm::seq(
256-
rank - 1, (int64_t)shape.size() - 1)));
257-
lhs = builder.create<tt::TransOp>(loc, lhs, transOrder);
258-
// We split recursively
259-
SmallVector<Value> curr;
260-
SmallVector<Value> ret = {lhs};
261-
for (int i = 1; i < nSplits; i *= 2) {
262-
curr = ret;
263-
ret.clear();
264-
for (auto v : curr) {
265-
auto split = builder.create<tt::SplitOp>(loc, v);
266-
ret.push_back(split.getResult(0));
267-
ret.push_back(split.getResult(1));
268-
}
269-
}
270-
271-
auto mmav3Type =
272-
type.clone(cast<RankedTensorType>(ret.front().getType()).getShape());
273-
// Convert the LHS to mmav3 layout
274-
for (auto &v : ret) {
275-
v = builder.create<ttg::ConvertLayoutOp>(loc, mmav3Type, v);
276-
// The layouts are noops by construction
277-
assert(minimalCvtLayout(v.getType(), mmav3Type) ==
278-
tt::LinearLayout::empty());
279-
}
280-
assert(ret.size() == nSplits);
281-
return ret;
282-
}
283-
284-
// Split the RHS of a RSWGMMADot operation into multiple
285-
// tensors of size newKxN via MemDescSubview
286-
SmallVector<Value> splitRhs(OpBuilder &builder,
287-
TypedValue<ttg::MemDescType> rhs, int64_t newK) {
288-
auto loc = rhs.getLoc();
289-
auto type = rhs.getType();
290-
auto rank = type.getRank();
291-
auto kDim = rank - 2;
292-
auto nSplits = type.getShape()[kDim] / newK;
293-
auto shape = llvm::to_vector(type.getShape());
294-
shape[kDim] = newK;
295-
SmallVector<Value> offsetsVal;
296-
for (int i = 0; i < rank; i++) {
297-
offsetsVal.push_back(builder.create<arith::ConstantIntOp>(loc, 0, 32));
298-
}
299-
auto newType = ttg::MemDescType::get(
300-
shape, type.getElementType(), type.getEncoding(), type.getMemorySpace(),
301-
/*isMutable=*/false, type.getAllocShape());
302-
SmallVector<Value> ret;
303-
for (int i = 0; i < nSplits; i++) {
304-
offsetsVal[kDim] = builder.create<arith::ConstantIntOp>(loc, i * newK, 32);
305-
Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
306-
loc, newType, rhs, offsetsVal);
307-
ret.push_back(newSmem);
308-
}
309-
return ret;
310-
}
311-
312-
std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
313-
// Splits a wgmma(tensor, shmem) MxK, KxN -> MxN into
314-
// along K into multiple wgmma(tensor, shmem) Mx16, 16xN -> MxN
315-
// where 16 is the instruction size
316-
if (!isa<RankedTensorType>(dotOp.getA().getType())) {
317-
return {dotOp};
318-
}
319-
320-
auto a = cast<TypedValue<RankedTensorType>>(dotOp.getA());
321-
auto b = cast<TypedValue<ttg::MemDescType>>(dotOp.getB());
322-
auto origK = a.getType().getShape().back();
323-
auto newK = cast<ttg::NvidiaMmaEncodingAttr>(dotOp.getType().getEncoding())
324-
.getInstrShape()[2];
325-
auto numSplits = origK / newK;
326-
// Nothing to split
327-
if (numSplits <= 1) {
328-
return {dotOp};
329-
}
330-
331-
assert(origK % newK == 0 && "origK must be divisible by newK");
332-
auto builder = OpBuilder(dotOp);
333-
auto loc = dotOp.getLoc();
334-
auto lhss = splitLhs(builder, a, newK);
335-
auto rhss = splitRhs(builder, b, newK);
336-
assert(lhss.size() == numSplits && "lhs must have the same number of splits");
337-
assert(rhss.size() == numSplits && "rhs must have the same number of splits");
338-
339-
Value useC = dotOp.getUseC();
340-
Value C = dotOp.getC();
341-
auto numImpreciseAccLeft = dotOp.getMaxNumImpreciseAcc();
342-
std::vector<ttng::WarpGroupDotOp> dots;
343-
for (int i = 0; i < numSplits; i++) {
344-
// 2**30 is to prevent the subtile from adding
345-
// extra imprecise accumulator, See WGMMA.cpp
346-
uint32_t numImpreciseAcc = (numImpreciseAccLeft > newK)
347-
? 1073741824 // 2**30
348-
: numImpreciseAccLeft;
349-
// Deduct the actual consumed imprecise acc
350-
numImpreciseAccLeft -= std::min(numImpreciseAccLeft, newK);
351-
auto dot = builder.create<ttng::WarpGroupDotOp>(
352-
loc, dotOp.getType(), lhss[i], rhss[i], C, useC,
353-
dotOp.getInputPrecision(), numImpreciseAcc, dotOp.getIsAsync());
354-
dots.push_back(dot);
355-
C = dot.getResult();
356-
useC = builder.create<mlir::arith::ConstantIntOp>(loc, 1, 1);
357-
}
358-
dotOp.replaceAllUsesWith(dots.back().getResult());
359-
dotOp.erase();
360-
return dots;
361-
}
362-
363-
// Apply splitRSDot to all dots in the input list.
364-
llvm::MapVector<Operation *, int>
365-
splitRSDots(const llvm::MapVector<Operation *, int> &dots) {
366-
llvm::MapVector<Operation *, int> ret;
367-
for (auto [dot, iterArgIdx] : dots) {
368-
auto newDots = splitRSDot(cast<ttng::WarpGroupDotOp>(dot));
369-
for (auto newDot : newDots) {
370-
ret.insert({newDot, iterArgIdx});
371-
}
372-
}
373-
return ret;
374-
}
375-
376209
// Determines whether a given MMAv3 dot op, represented as ttng.warp_group_dot,
377210
// needs a wait immediately after it.
378211
//
@@ -427,11 +260,21 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
427260
scf::ForOp forOp) {
428261
LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp);
429262

263+
// Rule 1: All shmem operands are multi-buffered.
430264
auto checkOperand = [&](Value operand) {
431-
// We can always make RSGEMM async s long as the RHS can be multi-buffered
432-
if (isa<RankedTensorType>(operand.getType())) {
433-
return true;
265+
if (!isa<ttg::SharedEncodingTrait>(
266+
cast<ttg::TensorOrMemDesc>(operand.getType()).getEncoding())) {
267+
// Rule 1a: Register operands must not be modified within the loop.
268+
// First, check for chained WGMMA as an exception.
269+
if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(operand.getDefiningOp())) {
270+
return isa<ttg::NvidiaMmaEncodingAttr>(
271+
cvt.getSrc().getType().getEncoding());
272+
}
273+
// And then, do a stricter-than-necessary check for now, that the operand
274+
// is defined outside the loop.
275+
return forOp.isDefinedOutsideOfLoop(operand);
434276
}
277+
435278
// If it's a shmem operand, it must either be defined outside the loop, or
436279
// come from an MemDescSubview op. Only ConvertLayout and view ops are
437280
// allowed in between.
@@ -453,7 +296,6 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
453296
transitiveOperand.getDefiningOp<ttg::MemDescSubviewOp>();
454297
};
455298

456-
// Rule 1: All shmem operands are multi-buffered.
457299
// We don't have to call checkOperand on getC() because it's always in
458300
// registers, never in shmem.
459301
assert(isa<ttg::NvidiaMmaEncodingAttr>(dotOp.getC().getType().getEncoding()));
@@ -473,13 +315,6 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
473315
while (!queue.empty()) {
474316
auto [user, argIdx] = queue.pop_back_val();
475317
if (user->getParentOp() == forOp) {
476-
// We support noops in between the dot and the yield
477-
if (isNoop(user)) {
478-
for (auto &use : user->getResult(0).getUses()) {
479-
queue.push_back({use.getOwner(), use.getOperandNumber()});
480-
}
481-
continue;
482-
}
483318
if (isa<scf::YieldOp>(user)) {
484319
if (iterArg) {
485320
// The dot is used by the loop's yield, but we can't have any other
@@ -508,28 +343,15 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
508343
return std::nullopt;
509344
}
510345
}
511-
// Rule 2.1: We don't make the dot async if the accumulator is not fp32.
512-
if (!dotOp.getC().getType().getElementType().isF32()) {
513-
LDBG("Can't make dot async because the accumulator is not fp32");
514-
return std::nullopt;
515-
}
516-
517-
// Rule 3a: Check that every use of the dot’s result (iterArg) eventually
518-
// reaches a WarpGroupDotOp (with use index 2), possibly after passing through
519-
// a chain of noops
520-
std::function<bool(OpOperand &)> isTransitivelyWarpGroupDot =
521-
[&](OpOperand &use) -> bool {
522-
Operation *user = use.getOwner();
523-
if (isa<ttng::WarpGroupDotOp>(user))
524-
return use.getOperandNumber() == 2;
525-
if (isNoop(user))
526-
return llvm::all_of(user->getResult(0).getUses(),
527-
isTransitivelyWarpGroupDot);
528-
return false;
529-
};
530346

531-
if (llvm::all_of(iterArg.getUses(), isTransitivelyWarpGroupDot))
347+
// Rule 3a: Are the only users of the dot's result from iteration i-1 other
348+
// MMAv3 dots? If so, we're done, this dot can be properly async.
349+
if (llvm::all_of(iterArg.getUses(), [&](OpOperand &use) {
350+
return isa<ttng::WarpGroupDotOp>(use.getOwner()) &&
351+
use.getOperandNumber() == 2;
352+
})) {
532353
return iterArgIdx;
354+
}
533355

534356
// Rule 3b: Are all users of the dot's result from iteration i-1 after the
535357
// first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be
@@ -592,21 +414,7 @@ static void insertAsyncWarpGroupDotWaitInLoop(
592414

593415
// Insert waits before the users of the properly async dots other than loop
594416
// yield.
595-
for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) {
596-
// If the dot takes the LHS on registers i, we add a wait for the number
597-
// of properly async dots in the loop minus one.
598-
// This makes sure that the dot will wait until itself from the previous
599-
// iteration has completed, as to avoid rewriting the registers.
600-
if (isRSDotFromSIMD(asyncDot, forOp)) {
601-
OpBuilder builder(asyncDot);
602-
builder.setInsertionPointAfter(asyncDot);
603-
auto newWait = builder.create<ttng::WarpGroupDotWaitOp>(
604-
asyncDot->getLoc(), ArrayRef<Value>{}, properlyAsyncDots.size() - 1);
605-
SmallVector<Value> waitOperands = {asyncDot->getResult(0)};
606-
threadValuesThroughWait(newWait, waitOperands);
607-
continue;
608-
}
609-
417+
for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) {
610418
SmallVector<OpOperand *> uses;
611419
for (auto &use : asyncDot->getUses()) {
612420
if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner())) {
@@ -640,11 +448,6 @@ static void insertAsyncWarpGroupDotWaitInLoop(
640448
// by a dot.)
641449
IRRewriter builder(forOp.getContext());
642450
auto lastAsyncDot = properlyAsyncDots.back().first;
643-
// If the last dot is an RS dot, we don't need to insert a wait
644-
// as we have already inserted a wait(properlyAsyncDots.size() - 1)
645-
if (isRSDotFromSIMD(lastAsyncDot, forOp)) {
646-
return;
647-
}
648451
builder.setInsertionPointAfter(lastAsyncDot);
649452
auto wait = builder.create<ttng::WarpGroupDotWaitOp>(
650453
lastAsyncDot->getLoc(),
@@ -701,11 +504,6 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
701504
return;
702505
}
703506

704-
// Split RS dots into dots with K = 16 (the instruction size of MMAv3)
705-
// If we split them in nSplit dots, we will be able to keep nSplit-1 dots
706-
// in flight at a time.
707-
properlyAsyncDots = splitRSDots(properlyAsyncDots);
708-
709507
// Next, insert a wait inside the loop. We pipeline to depth 2, so the third
710508
// iteration's set of asynchronous dots (and their corresponding async copies
711509
// from global to shmem) can't start until the first iteration's set has

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,6 @@ bool isView(Operation *op) {
149149
return isa<ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp>(op);
150150
}
151151

152-
bool isNoop(Operation *op) {
153-
if (isa<ReshapeOp, TransOp>(op))
154-
return true;
155-
if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(op)) {
156-
// The conversion op is a noop if the conversion layout is trivial
157-
return minimalCvtLayout(cvt.getSrc().getType(),
158-
cvt.getResult().getType()) == LinearLayout::empty();
159-
}
160-
return false;
161-
}
162-
163152
//===----------------------------------------------------------------------===//
164153
// GraphDumper
165154
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)