 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
+#include "triton/Tools/LinearLayout.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -30,6 +31,30 @@ namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
 namespace ttng = mlir::triton::nvidia_gpu;
 
+// Returns whether the dot is such that:
+// 1. The LHS comes from registers, and
+//    1.1. the LHS is defined inside the loop, and
+//    1.2. the LHS does not come from another dot.
+// For these dots, we assume that we cannot overwrite their
+// operands until the previous dot has finished.
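+// (For example, an LHS that is loaded and then rescaled or upcast by
+// elementwise ops inside the loop body qualifies; an LHS converted from an
+// NvidiaMmaEncodingAttr layout, i.e. a chained WGMMA, does not.)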
+static bool isRSDotFromSIMD(Operation *dot, scf::ForOp forOp) {
+  auto dotOp = dyn_cast<ttng::WarpGroupDotOp>(dot);
+  if (!dotOp)
+    return false;
+  auto a = dotOp.getA();
+  if (!isa<RankedTensorType>(a.getType())) {
+    return false;
+  }
+  if (forOp.isDefinedOutsideOfLoop(a)) {
+    return false;
+  }
+  if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(a.getDefiningOp())) {
+    return !isa<ttg::NvidiaMmaEncodingAttr>(
+        cvt.getSrc().getType().getEncoding());
+  }
+  return true;
+}
+
 /// Find the minimum number of async_commit_group ops between the wait
 /// and the associated async_commit_group. This can be safely used as the wait
 /// number.
@@ -206,6 +231,148 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait,
   wait->erase();
 }
 
+// Split the LHS of an RS WGMMA dot operation along K into multiple
+// tensors of size MxnewK via SplitOps.
+SmallVector<Value> splitLhs(OpBuilder &builder,
+                            TypedValue<RankedTensorType> lhs, int64_t newK) {
+  auto loc = lhs.getLoc();
+  auto type = lhs.getType();
+  auto rank = type.getRank();
+  auto shape = to_vector(type.getShape());
+  auto nSplits = shape.back() / newK;
+  assert(nSplits > 1);
+  // Reshape K == 2x...x2xnewK
+  shape.pop_back();
+  for (int i = 1; i < nSplits; i *= 2) {
+    shape.push_back(2);
+  }
+  shape.push_back(newK);
+  lhs = builder.create<tt::ReshapeOp>(loc, shape, lhs);
+  // We want to split first the slowest running dim, then the second slowest,
+  // etc.
+  auto transOrder = to_vector(llvm::seq<int>(rank - 1));
+  transOrder.push_back(shape.size() - 1);
+  llvm::append_range(transOrder, llvm::reverse(llvm::seq(
+                                     rank - 1, (int64_t)shape.size() - 1)));
+  lhs = builder.create<tt::TransOp>(loc, lhs, transOrder);
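+  // E.g. for a rank-2 LHS with M = 128, K = 64 and newK = 16 (illustrative
+  // sizes), the reshape gives [128, 2, 2, 16] and transOrder is [0, 3, 2, 1],
+  // so the last dim after the transpose is the slowest-running split dim.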
+  // We split recursively
+  SmallVector<Value> curr;
+  SmallVector<Value> ret = {lhs};
+  for (int i = 1; i < nSplits; i *= 2) {
+    curr = ret;
+    ret.clear();
+    for (auto v : curr) {
+      auto split = builder.create<tt::SplitOp>(loc, v);
+      ret.push_back(split.getResult(0));
+      ret.push_back(split.getResult(1));
+    }
+  }
+
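+  // Each round splits off the trailing size-2 dim, so after log2(nSplits)
+  // rounds ret holds nSplits tensors, ordered by increasing K offset.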
+  auto mmav3Type =
+      type.clone(cast<RankedTensorType>(ret.front().getType()).getShape());
+  // Convert the LHS to mmav3 layout
+  for (auto &v : ret) {
+    // The layouts are noops by construction, so check before converting
+    assert(minimalCvtLayout(v.getType(), mmav3Type) ==
+           tt::LinearLayout::empty());
+    v = builder.create<ttg::ConvertLayoutOp>(loc, mmav3Type, v);
+  }
+  assert(ret.size() == nSplits);
+  return ret;
+}
+
+// Split the RHS of an RS WGMMA dot operation along K into multiple
+// tensors of size newKxN via MemDescSubview.
+SmallVector<Value> splitRhs(OpBuilder &builder,
+                            TypedValue<ttg::MemDescType> rhs, int64_t newK) {
+  auto loc = rhs.getLoc();
+  auto type = rhs.getType();
+  auto rank = type.getRank();
+  auto kDim = rank - 2;
+  auto nSplits = type.getShape()[kDim] / newK;
+  auto shape = llvm::to_vector(type.getShape());
+  shape[kDim] = newK;
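+  // E.g. a 64x256 RHS in shmem with newK == 16 (illustrative sizes) yields
+  // four 16x256 subviews at K offsets 0, 16, 32 and 48.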
+  SmallVector<Value> offsetsVal;
+  for (int i = 0; i < rank; i++) {
+    offsetsVal.push_back(builder.create<arith::ConstantIntOp>(loc, 0, 32));
+  }
+  auto newType = ttg::MemDescType::get(
+      shape, type.getElementType(), type.getEncoding(), type.getMemorySpace(),
+      /*isMutable=*/false, type.getAllocShape());
+  SmallVector<Value> ret;
+  for (int i = 0; i < nSplits; i++) {
+    offsetsVal[kDim] = builder.create<arith::ConstantIntOp>(loc, i * newK, 32);
+    Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
+        loc, newType, rhs, offsetsVal);
+    ret.push_back(newSmem);
+  }
+  return ret;
+}
+
+std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
+  // Splits a wgmma(tensor, shmem) MxK, KxN -> MxN along K into multiple
+  // wgmma(tensor, shmem) Mx16, 16xN -> MxN, where 16 is the instruction size
+  if (!isa<RankedTensorType>(dotOp.getA().getType())) {
+    return {dotOp};
+  }
+
+  auto a = cast<TypedValue<RankedTensorType>>(dotOp.getA());
+  auto b = cast<TypedValue<ttg::MemDescType>>(dotOp.getB());
+  auto origK = a.getType().getShape().back();
+  auto newK = cast<ttg::NvidiaMmaEncodingAttr>(dotOp.getType().getEncoding())
+                  .getInstrShape()[2];
+  auto numSplits = origK / newK;
+  // Nothing to split
+  if (numSplits <= 1) {
+    return {dotOp};
+  }
+
+  assert(origK % newK == 0 && "origK must be divisible by newK");
+  auto builder = OpBuilder(dotOp);
+  auto loc = dotOp.getLoc();
+  auto lhss = splitLhs(builder, a, newK);
+  auto rhss = splitRhs(builder, b, newK);
+  assert(lhss.size() == numSplits && "lhs must have the same number of splits");
+  assert(rhss.size() == numSplits && "rhs must have the same number of splits");
+
+  Value useC = dotOp.getUseC();
+  Value C = dotOp.getC();
+  auto numImpreciseAccLeft = dotOp.getMaxNumImpreciseAcc();
+  std::vector<ttng::WarpGroupDotOp> dots;
+  for (int i = 0; i < numSplits; i++) {
+    // 2**30 is to prevent the subtile from adding an
+    // extra imprecise accumulator; see WGMMA.cpp
+    uint32_t numImpreciseAcc = (numImpreciseAccLeft > newK)
+                                   ? 1073741824 // 2**30
+                                   : numImpreciseAccLeft;
+    // Deduct the actual consumed imprecise acc
+    numImpreciseAccLeft -= std::min(numImpreciseAccLeft, newK);
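+    // E.g. with maxNumImpreciseAcc == 40 and newK == 16 (illustrative
+    // numbers), the first three splits get 2**30, 2**30 and 8; any further
+    // splits get 0.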
+    auto dot = builder.create<ttng::WarpGroupDotOp>(
+        loc, dotOp.getType(), lhss[i], rhss[i], C, useC,
+        dotOp.getInputPrecision(), numImpreciseAcc, dotOp.getIsAsync());
+    dots.push_back(dot);
+    C = dot.getResult();
+    useC = builder.create<mlir::arith::ConstantIntOp>(loc, 1, 1);
+  }
+  dotOp.replaceAllUsesWith(dots.back().getResult());
+  dotOp.erase();
+  return dots;
+}
+
+// Apply splitRSDot to all dots in the input list.
+llvm::MapVector<Operation *, int>
+splitRSDots(const llvm::MapVector<Operation *, int> &dots) {
+  llvm::MapVector<Operation *, int> ret;
+  for (auto [dot, iterArgIdx] : dots) {
+    auto newDots = splitRSDot(cast<ttng::WarpGroupDotOp>(dot));
+    for (auto newDot : newDots) {
+      ret.insert({newDot, iterArgIdx});
+    }
+  }
+  return ret;
+}
+
 // Determines whether a given MMAv3 dot op, represented as ttng.warp_group_dot,
 // needs a wait immediately after it.
 //
@@ -260,21 +427,11 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
                                                 scf::ForOp forOp) {
   LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp);
 
-  // Rule 1: All shmem operands are multi-buffered.
   auto checkOperand = [&](Value operand) {
-    if (!isa<ttg::SharedEncodingTrait>(
-            cast<ttg::TensorOrMemDesc>(operand.getType()).getEncoding())) {
-      // Rule 1a: Register operands must not be modified within the loop.
-      // First, check for chained WGMMA as an exception.
-      if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(operand.getDefiningOp())) {
-        return isa<ttg::NvidiaMmaEncodingAttr>(
-            cvt.getSrc().getType().getEncoding());
-      }
-      // And then, do a stricter-than-necessary check for now, that the operand
-      // is defined outside the loop.
-      return forOp.isDefinedOutsideOfLoop(operand);
+    // We can always make an RS GEMM async as long as the RHS can be
+    // multi-buffered
+    if (isa<RankedTensorType>(operand.getType())) {
+      return true;
     }
-
     // If it's a shmem operand, it must either be defined outside the loop, or
     // come from an MemDescSubview op. Only ConvertLayout and view ops are
     // allowed in between.
@@ -296,6 +453,7 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
            transitiveOperand.getDefiningOp<ttg::MemDescSubviewOp>();
   };
 
+  // Rule 1: All shmem operands are multi-buffered.
   // We don't have to call checkOperand on getC() because it's always in
   // registers, never in shmem.
   assert(isa<ttg::NvidiaMmaEncodingAttr>(dotOp.getC().getType().getEncoding()));
@@ -315,6 +473,13 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
   while (!queue.empty()) {
     auto [user, argIdx] = queue.pop_back_val();
     if (user->getParentOp() == forOp) {
+      // We support noops in between the dot and the yield
+      if (isNoop(user)) {
+        for (auto &use : user->getResult(0).getUses()) {
+          queue.push_back({use.getOwner(), use.getOperandNumber()});
+        }
+        continue;
+      }
       if (isa<scf::YieldOp>(user)) {
         if (iterArg) {
           // The dot is used by the loop's yield, but we can't have any other
@@ -343,15 +508,28 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
       return std::nullopt;
     }
   }
+  // Rule 2.1: We don't make the dot async if the accumulator is not fp32.
+  if (!dotOp.getC().getType().getElementType().isF32()) {
+    LDBG("Can't make dot async because the accumulator is not fp32");
+    return std::nullopt;
+  }
 
-  // Rule 3a: Are the only users of the dot's result from iteration i-1 other
-  // MMAv3 dots? If so, we're done, this dot can be properly async.
-  if (llvm::all_of(iterArg.getUses(), [&](OpOperand &use) {
-        return isa<ttng::WarpGroupDotOp>(use.getOwner()) &&
-               use.getOperandNumber() == 2;
-      })) {
+  // Rule 3a: Check that every use of the dot's result (iterArg) eventually
+  // reaches a WarpGroupDotOp (with use index 2), possibly after passing
+  // through a chain of noops
+  std::function<bool(OpOperand &)> isTransitivelyWarpGroupDot =
+      [&](OpOperand &use) -> bool {
+    Operation *user = use.getOwner();
+    if (isa<ttng::WarpGroupDotOp>(user))
+      return use.getOperandNumber() == 2;
+    if (isNoop(user))
+      return llvm::all_of(user->getResult(0).getUses(),
+                          isTransitivelyWarpGroupDot);
+    return false;
+  };
+
+  if (llvm::all_of(iterArg.getUses(), isTransitivelyWarpGroupDot))
     return iterArgIdx;
-  }
 
   // Rule 3b: Are all users of the dot's result from iteration i-1 after the
   // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be
@@ -414,7 +592,21 @@ static void insertAsyncWarpGroupDotWaitInLoop(
 
   // Insert waits before the users of the properly async dots other than loop
   // yield.
-  for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) {
+  for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) {
+    // If the dot takes its LHS in registers, we add a wait for the number
+    // of properly async dots in the loop minus one.
+    // This makes sure that the dot will wait until itself from the previous
+    // iteration has completed, so as to avoid overwriting the registers.
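+    // (E.g. with 4 properly async dots in the loop body, wait(3) forces the
+    // same dot from the previous iteration to have retired before its input
+    // registers are rewritten.)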
+    if (isRSDotFromSIMD(asyncDot, forOp)) {
+      OpBuilder builder(asyncDot);
+      builder.setInsertionPointAfter(asyncDot);
+      auto newWait = builder.create<ttng::WarpGroupDotWaitOp>(
+          asyncDot->getLoc(), ArrayRef<Value>{}, properlyAsyncDots.size() - 1);
+      SmallVector<Value> waitOperands = {asyncDot->getResult(0)};
+      threadValuesThroughWait(newWait, waitOperands);
+      continue;
+    }
+
     SmallVector<OpOperand *> uses;
     for (auto &use : asyncDot->getUses()) {
       if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner())) {
@@ -448,6 +640,11 @@ static void insertAsyncWarpGroupDotWaitInLoop(
   // by a dot.)
   IRRewriter builder(forOp.getContext());
   auto lastAsyncDot = properlyAsyncDots.back().first;
+  // If the last dot is an RS dot, we don't need to insert a wait here,
+  // as we have already inserted a wait(properlyAsyncDots.size() - 1)
+  // right after it.
+  if (isRSDotFromSIMD(lastAsyncDot, forOp)) {
+    return;
+  }
   builder.setInsertionPointAfter(lastAsyncDot);
   auto wait = builder.create<ttng::WarpGroupDotWaitOp>(
       lastAsyncDot->getLoc(),
@@ -504,6 +701,11 @@ void triton::asyncLaunchDots(scf::ForOp forOp) {
     return;
   }
 
+  // Split RS dots into dots with K = 16 (the instruction size of MMAv3).
+  // If we split them into nSplit dots, we will be able to keep nSplit-1 dots
+  // in flight at a time.
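+  // (E.g. an RS dot with K = 64 is split into 4 dots of K = 16, so up to 3
+  // of them can be in flight at once.)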
+  properlyAsyncDots = splitRSDots(properlyAsyncDots);
+
   // Next, insert a wait inside the loop. We pipeline to depth 2, so the third
   // iteration's set of asynchronous dots (and their corresponding async copies
   // from global to shmem) can't start until the first iteration's set has