 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
-#include "triton/Tools/LinearLayout.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -31,30 +30,6 @@ namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
 namespace ttng = mlir::triton::nvidia_gpu;
 
-// Returns whether the dot is such that:
-// 1. The LHS comes from registers and
-//    1.1. The LHS is defined inside the loop
-//    1.2. The LHS does not come from another dot
-// For these dots, we assume that we cannot rewrite their
-// operands until the previous dot has finished.
-static bool isRSDotFromSIMD(Operation *dot, scf::ForOp forOp) {
-  auto dotOp = dyn_cast<ttng::WarpGroupDotOp>(dot);
-  if (!dotOp)
-    return false;
-  auto a = dotOp.getA();
-  if (!isa<RankedTensorType>(a.getType())) {
-    return false;
-  }
-  if (forOp.isDefinedOutsideOfLoop(a)) {
-    return false;
-  }
-  if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(a.getDefiningOp())) {
-    return !isa<ttg::NvidiaMmaEncodingAttr>(
-        cvt.getSrc().getType().getEncoding());
-  }
-  return true;
-}
-
 /// Find the minimum number of async_commit_group ops between the wait
 /// and the associated async_commit_group. This can be safely used as the wait
 /// number.
@@ -231,148 +206,6 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait,
   wait->erase();
 }
 
-// Split the LHS of an RSWGMMADot operation into multiple
-// tensors of size MxnewK via SplitOps.
-SmallVector<Value> splitLhs(OpBuilder &builder,
-                            TypedValue<RankedTensorType> lhs, int64_t newK) {
-  auto loc = lhs.getLoc();
-  auto type = lhs.getType();
-  auto rank = type.getRank();
-  auto shape = to_vector(type.getShape());
-  auto nSplits = shape.back() / newK;
-  assert(nSplits > 1);
-  // Reshape K == 2x..x2xnewK
-  shape.pop_back();
-  for (int i = 1; i < nSplits; i *= 2) {
-    shape.push_back(2);
-  }
-  shape.push_back(newK);
-  lhs = builder.create<tt::ReshapeOp>(loc, shape, lhs);
-  // We want to split first the slowest running dim, then the second slowest,
-  // etc.
-  auto transOrder = to_vector(llvm::seq<int>(rank - 1));
-  transOrder.push_back(shape.size() - 1);
-  llvm::append_range(transOrder, llvm::reverse(llvm::seq(
-                                     rank - 1, (int64_t)shape.size() - 1)));
-  lhs = builder.create<tt::TransOp>(loc, lhs, transOrder);
-  // We split recursively
-  SmallVector<Value> curr;
-  SmallVector<Value> ret = {lhs};
-  for (int i = 1; i < nSplits; i *= 2) {
-    curr = ret;
-    ret.clear();
-    for (auto v : curr) {
-      auto split = builder.create<tt::SplitOp>(loc, v);
-      ret.push_back(split.getResult(0));
-      ret.push_back(split.getResult(1));
-    }
-  }
-
-  auto mmav3Type =
-      type.clone(cast<RankedTensorType>(ret.front().getType()).getShape());
-  // Convert the LHS to the mmav3 layout
-  for (auto &v : ret) {
-    v = builder.create<ttg::ConvertLayoutOp>(loc, mmav3Type, v);
-    // The layout conversions are noops by construction
-    assert(minimalCvtLayout(v.getType(), mmav3Type) ==
-           tt::LinearLayout::empty());
-  }
-  assert(ret.size() == nSplits);
-  return ret;
-}
-
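The index bookkeeping in the removed splitLhs is easier to follow on a concrete shape. Below is a minimal standalone sketch (plain C++, no MLIR; the 128x64 LHS and newK = 16 are assumed purely for illustration) of the reshape and transpose-order computation:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      int64_t rank = 2, M = 128, K = 64, newK = 16; // assumed example shape
      // Reshape K into 2x..x2xnewK, as splitLhs does: [128, 2, 2, 16].
      std::vector<int64_t> shape = {M};
      for (int64_t i = 1; i < K / newK; i *= 2)
        shape.push_back(2);
      shape.push_back(newK);
      // Transpose order: leading dims first, then newK, then the size-2 dims
      // slowest-running first: [0, 3, 2, 1], so shape becomes [128, 16, 2, 2].
      std::vector<int64_t> transOrder;
      for (int64_t d = 0; d < rank - 1; ++d)
        transOrder.push_back(d);
      transOrder.push_back((int64_t)shape.size() - 1);
      for (int64_t d = (int64_t)shape.size() - 2; d >= rank - 1; --d)
        transOrder.push_back(d);
      for (int64_t d : transOrder)
        std::cout << d << ' '; // prints: 0 3 2 1
      std::cout << '\n';
    }

Two rounds of tt::SplitOp on the trailing size-2 dimensions then yield the four 128x16 tensors.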
-// Split the RHS of an RSWGMMADot operation into multiple
-// tensors of size newKxN via MemDescSubview.
-SmallVector<Value> splitRhs(OpBuilder &builder,
-                            TypedValue<ttg::MemDescType> rhs, int64_t newK) {
-  auto loc = rhs.getLoc();
-  auto type = rhs.getType();
-  auto rank = type.getRank();
-  auto kDim = rank - 2;
-  auto nSplits = type.getShape()[kDim] / newK;
-  auto shape = llvm::to_vector(type.getShape());
-  shape[kDim] = newK;
-  SmallVector<Value> offsetsVal;
-  for (int i = 0; i < rank; i++) {
-    offsetsVal.push_back(builder.create<arith::ConstantIntOp>(loc, 0, 32));
-  }
-  auto newType = ttg::MemDescType::get(
-      shape, type.getElementType(), type.getEncoding(), type.getMemorySpace(),
-      /*isMutable=*/false, type.getAllocShape());
-  SmallVector<Value> ret;
-  for (int i = 0; i < nSplits; i++) {
-    offsetsVal[kDim] = builder.create<arith::ConstantIntOp>(loc, i * newK, 32);
-    Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
-        loc, newType, rhs, offsetsVal);
-    ret.push_back(newSmem);
-  }
-  return ret;
-}
-
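A similar sketch (again with assumed shapes, no MLIR) of the offsets splitRhs derives for its shmem subviews:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Assumed example: a 64x256 RHS in shmem (rank 2), newK = 16.
      std::vector<int64_t> shape = {64, 256};
      int64_t rank = 2, kDim = rank - 2, newK = 16; // kDim is 0 for a KxN operand
      int64_t nSplits = shape[kDim] / newK;
      for (int64_t i = 0; i < nSplits; ++i)
        std::cout << "subview " << i << ": K-offset " << i * newK << ", shape "
                  << newK << "x" << shape[1] << '\n';
      // prints K-offsets 0, 16, 32, 48: four 16x256 views of one buffer
    }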
-std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
-  // Splits a wgmma(tensor, shmem) MxK, KxN -> MxN along K into multiple
-  // wgmma(tensor, shmem) Mx16, 16xN -> MxN, where 16 is the
-  // instruction size.
-  if (!isa<RankedTensorType>(dotOp.getA().getType())) {
-    return {dotOp};
-  }
-
-  auto a = cast<TypedValue<RankedTensorType>>(dotOp.getA());
-  auto b = cast<TypedValue<ttg::MemDescType>>(dotOp.getB());
-  auto origK = a.getType().getShape().back();
-  auto newK = cast<ttg::NvidiaMmaEncodingAttr>(dotOp.getType().getEncoding())
-                  .getInstrShape()[2];
-  auto numSplits = origK / newK;
-  // Nothing to split
-  if (numSplits <= 1) {
-    return {dotOp};
-  }
-
-  assert(origK % newK == 0 && "origK must be divisible by newK");
-  auto builder = OpBuilder(dotOp);
-  auto loc = dotOp.getLoc();
-  auto lhss = splitLhs(builder, a, newK);
-  auto rhss = splitRhs(builder, b, newK);
-  assert(lhss.size() == numSplits && "lhs must have the same number of splits");
-  assert(rhss.size() == numSplits && "rhs must have the same number of splits");
-
-  Value useC = dotOp.getUseC();
-  Value C = dotOp.getC();
-  auto numImpreciseAccLeft = dotOp.getMaxNumImpreciseAcc();
-  std::vector<ttng::WarpGroupDotOp> dots;
-  for (int i = 0; i < numSplits; i++) {
-    // 2**30 is to prevent the subtile from adding an
-    // extra imprecise accumulator; see WGMMA.cpp.
-    uint32_t numImpreciseAcc = (numImpreciseAccLeft > newK)
-                                   ? 1073741824 // 2**30
-                                   : numImpreciseAccLeft;
-    // Deduct the imprecise accs actually consumed by this subtile.
-    numImpreciseAccLeft -= std::min(numImpreciseAccLeft, newK);
-    auto dot = builder.create<ttng::WarpGroupDotOp>(
-        loc, dotOp.getType(), lhss[i], rhss[i], C, useC,
-        dotOp.getInputPrecision(), numImpreciseAcc, dotOp.getIsAsync());
-    dots.push_back(dot);
-    C = dot.getResult();
-    useC = builder.create<mlir::arith::ConstantIntOp>(loc, 1, 1);
-  }
-  dotOp.replaceAllUsesWith(dots.back().getResult());
-  dotOp.erase();
-  return dots;
-}
-
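The maxNumImpreciseAcc bookkeeping is the subtle part of the removed splitRSDot. A standalone sketch, mirroring only the arithmetic above, with an assumed budget of 32 imprecise accumulations over four K = 16 subtiles:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      // Assumed: origK = 64 split into four K = 16 subtiles, with a budget of
      // 32 imprecise accumulations on the original dot.
      uint32_t newK = 16, numSplits = 4, left = 32;
      for (uint32_t i = 0; i < numSplits; ++i) {
        // While more budget remains than one subtile consumes, pass the 2**30
        // sentinel so the subtile adds no extra imprecise accumulator (see
        // WGMMA.cpp); the first subtile not fully covered gets the remainder.
        uint32_t numImpreciseAcc = (left > newK) ? (1u << 30) : left;
        left -= std::min(left, newK);
        std::cout << "subtile " << i << ": " << numImpreciseAcc << '\n';
      }
      // prints: 1073741824, 16, 0, 0
    }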
-// Apply splitRSDot to all dots in the input list.
-llvm::MapVector<Operation *, int>
-splitRSDots(const llvm::MapVector<Operation *, int> &dots) {
-  llvm::MapVector<Operation *, int> ret;
-  for (auto [dot, iterArgIdx] : dots) {
-    auto newDots = splitRSDot(cast<ttng::WarpGroupDotOp>(dot));
-    for (auto newDot : newDots) {
-      ret.insert({newDot, iterArgIdx});
-    }
-  }
-  return ret;
-}
-
 // Determines whether a given MMAv3 dot op, represented as ttng.warp_group_dot,
 // needs a wait immediately after it.
 //
@@ -427,11 +260,21 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
                                                 scf::ForOp forOp) {
   LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp);
 
+  // Rule 1: All shmem operands are multi-buffered.
   auto checkOperand = [&](Value operand) {
-    // We can always make RSGEMM async as long as the RHS can be multi-buffered
-    if (isa<RankedTensorType>(operand.getType())) {
-      return true;
+    if (!isa<ttg::SharedEncodingTrait>(
+            cast<ttg::TensorOrMemDesc>(operand.getType()).getEncoding())) {
+      // Rule 1a: Register operands must not be modified within the loop.
+      // First, check for chained WGMMA as an exception.
+      if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(operand.getDefiningOp())) {
+        return isa<ttg::NvidiaMmaEncodingAttr>(
+            cvt.getSrc().getType().getEncoding());
+      }
+      // And then, do a stricter-than-necessary check for now, that the operand
+      // is defined outside the loop.
+      return forOp.isDefinedOutsideOfLoop(operand);
     }
+
     // If it's a shmem operand, it must either be defined outside the loop, or
     // come from a MemDescSubview op. Only ConvertLayout and view ops are
     // allowed in between.
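For readers skimming the hunk, the new register-operand rule can be restated as a tiny hypothetical predicate (the booleans stand in for the ConvertLayoutOp/NvidiaMmaEncodingAttr and isDefinedOutsideOfLoop queries on the real operand; a sketch, not part of the pass):

    // Hypothetical restatement of the register-operand rule above.
    static bool regOperandOkForAsync(bool isCvt, bool cvtSrcHasMmaEncoding,
                                     bool definedOutsideLoop) {
      // Chained WGMMA exception: a layout conversion is acceptable iff its
      // source already carries an NVIDIA MMA encoding (i.e. it comes from
      // another dot's accumulator).
      if (isCvt)
        return cvtSrcHasMmaEncoding;
      // Otherwise require loop invariance, a stricter-than-necessary check.
      return definedOutsideLoop;
    }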
@@ -453,7 +296,6 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
            transitiveOperand.getDefiningOp<ttg::MemDescSubviewOp>();
   };
 
-  // Rule 1: All shmem operands are multi-buffered.
   // We don't have to call checkOperand on getC() because it's always in
   // registers, never in shmem.
   assert(isa<ttg::NvidiaMmaEncodingAttr>(dotOp.getC().getType().getEncoding()));
@@ -473,13 +315,6 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
   while (!queue.empty()) {
     auto [user, argIdx] = queue.pop_back_val();
     if (user->getParentOp() == forOp) {
-      // We support noops in between the dot and the yield
-      if (isNoop(user)) {
-        for (auto &use : user->getResult(0).getUses()) {
-          queue.push_back({use.getOwner(), use.getOperandNumber()});
-        }
-        continue;
-      }
       if (isa<scf::YieldOp>(user)) {
         if (iterArg) {
           // The dot is used by the loop's yield, but we can't have any other
@@ -508,28 +343,15 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
       return std::nullopt;
     }
   }
-  // Rule 2.1: We don't make the dot async if the accumulator is not fp32.
-  if (!dotOp.getC().getType().getElementType().isF32()) {
-    LDBG("Can't make dot async because the accumulator is not fp32");
-    return std::nullopt;
-  }
-
-  // Rule 3a: Check that every use of the dot's result (iterArg) eventually
-  // reaches a WarpGroupDotOp (with use index 2), possibly after passing
-  // through a chain of noops.
-  std::function<bool(OpOperand &)> isTransitivelyWarpGroupDot =
-      [&](OpOperand &use) -> bool {
-    Operation *user = use.getOwner();
-    if (isa<ttng::WarpGroupDotOp>(user))
-      return use.getOperandNumber() == 2;
-    if (isNoop(user))
-      return llvm::all_of(user->getResult(0).getUses(),
-                          isTransitivelyWarpGroupDot);
-    return false;
-  };
 
-  if (llvm::all_of(iterArg.getUses(), isTransitivelyWarpGroupDot))
+  // Rule 3a: Are the only users of the dot's result from iteration i-1 other
+  // MMAv3 dots? If so, we're done, this dot can be properly async.
+  if (llvm::all_of(iterArg.getUses(), [&](OpOperand &use) {
+        return isa<ttng::WarpGroupDotOp>(use.getOwner()) &&
+               use.getOperandNumber() == 2;
+      })) {
     return iterArgIdx;
+  }
 
   // Rule 3b: Are all users of the dot's result from iteration i-1 after the
   // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be
@@ -592,21 +414,7 @@ static void insertAsyncWarpGroupDotWaitInLoop(
 
   // Insert waits before the users of the properly async dots other than loop
   // yield.
-  for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) {
-    // If the dot takes the LHS in registers, we add a wait for the number
-    // of properly async dots in the loop minus one.
-    // This makes sure that the dot waits until its own instance from the
-    // previous iteration has completed, so as to avoid rewriting the registers.
-    if (isRSDotFromSIMD(asyncDot, forOp)) {
-      OpBuilder builder(asyncDot);
-      builder.setInsertionPointAfter(asyncDot);
-      auto newWait = builder.create<ttng::WarpGroupDotWaitOp>(
-          asyncDot->getLoc(), ArrayRef<Value>{}, properlyAsyncDots.size() - 1);
-      SmallVector<Value> waitOperands = {asyncDot->getResult(0)};
-      threadValuesThroughWait(newWait, waitOperands);
-      continue;
-    }
-
+  for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) {
     SmallVector<OpOperand *> uses;
     for (auto &use : asyncDot->getUses()) {
       if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner())) {
@@ -640,11 +448,6 @@ static void insertAsyncWarpGroupDotWaitInLoop(
   // by a dot.)
   IRRewriter builder(forOp.getContext());
   auto lastAsyncDot = properlyAsyncDots.back().first;
-  // If the last dot is an RS dot, we don't need to insert a wait here,
-  // as we have already inserted a wait(properlyAsyncDots.size() - 1).
-  if (isRSDotFromSIMD(lastAsyncDot, forOp)) {
-    return;
-  }
   builder.setInsertionPointAfter(lastAsyncDot);
   auto wait = builder.create<ttng::WarpGroupDotWaitOp>(
       lastAsyncDot->getLoc(),
@@ -701,11 +504,6 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
     return;
   }
 
-  // Split RS dots into dots with K = 16 (the instruction size of MMAv3).
-  // If we split them into nSplit dots, we will be able to keep nSplit-1 dots
-  // in flight at a time.
-  properlyAsyncDots = splitRSDots(properlyAsyncDots);
-
   // Next, insert a wait inside the loop. We pipeline to depth 2, so the third
   // iteration's set of asynchronous dots (and their corresponding async copies
   // from global to shmem) can't start until the first iteration's set has