@@ -207,6 +207,17 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
207207 " loads for `tt.dot` operands" );
208208 }
209209
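+  // For scaled MMAs, also look for single-user load chains feeding the A and
+  // B scale operands. An empty chain means that scale is not fed by a
+  // pipelineable load.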
+  SmallVector<Operation *> aScaleChain, bScaleChain;
+  auto scaledMMAOp = dyn_cast<ttng::TCGen5MMAScaledOp>(mmaOp.getOperation());
+  if (scaledMMAOp) {
+    if (failed(
+            findSingleChainToLoad(loop, scaledMMAOp.getAScale(), aScaleChain)))
+      aScaleChain.clear();
+    if (failed(
+            findSingleChainToLoad(loop, scaledMMAOp.getBScale(), bScaleChain)))
+      bScaleChain.clear();
+  }
+
   ttng::TMEMAllocOp oldAccAlloc =
       mmaOp.getAccumulator().getDefiningOp<ttng::TMEMAllocOp>();
   if (!oldAccAlloc)
@@ -218,7 +229,9 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,

   // Determine if the MMA accumulator can be multibuffered.
   auto isLoadPipelineable = [&](Operation *op) {
-    return llvm::is_contained({aChain.back(), bChain.back()}, op);
+    return llvm::is_contained(llvm::to_vector(llvm::concat<Operation *>(
+                                  aChain, bChain, aScaleChain, bScaleChain)),
+                              op);
   };
   bool accIsMultiBuffered =
       // All operand feeds are pipelineable.
@@ -280,16 +293,27 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
   Partition *mmaPartition = schedule.addPartition(numStages);

   // Multi-buffer the loads.
-  auto [loadIndex, loadPhase] = addIndexAndPhase(b, loop, numStages);
+  BlockArgument loadIndex;
+  BlockArgument loadPhase;
+  std::tie(loadIndex, loadPhase) = addIndexAndPhase(b, loop, numStages);
+
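+  // Multi-buffer the shared memory allocation for the load terminating a
+  // chain. Returns null values for an empty chain so absent scale operands
+  // can be skipped below.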
+  auto allocate = [&](const SmallVector<Operation *> &chain)
+      -> std::tuple<Operation *, RankedTensorType, SharedEncodingTrait, Value> {
+    if (chain.empty())
+      return {nullptr, RankedTensorType(), SharedEncodingTrait(), Value()};
+
+    Operation *load = chain.back();
+    auto type = cast<RankedTensorType>(load->getResult(0).getType());
+    SharedEncodingTrait enc = getSharedEncoding(load);
+    Value alloc = createAlloc(loop, type, load->getLoc(), enc, numStages);
+
+    return {load, type, enc, alloc};
+  };

-  Operation *aLoad = aChain.back();
-  Operation *bLoad = bChain.back();
-  auto aType = cast<RankedTensorType>(aLoad->getResult(0).getType());
-  auto bType = cast<RankedTensorType>(bLoad->getResult(0).getType());
-  SharedEncodingTrait aEnc = getSharedEncoding(aChain.back());
-  SharedEncodingTrait bEnc = getSharedEncoding(bChain.back());
-  Value aAlloc = createAlloc(loop, aType, aLoad->getLoc(), aEnc, numStages);
-  Value bAlloc = createAlloc(loop, bType, bLoad->getLoc(), bEnc, numStages);
+  auto [aLoad, aType, aEnc, aAlloc] = allocate(aChain);
+  auto [bLoad, bType, bEnc, bAlloc] = allocate(bChain);
+  auto [aScaleLoad, aScaleType, aScaleEnc, aScaleAlloc] = allocate(aScaleChain);
+  auto [bScaleLoad, bScaleType, bScaleEnc, bScaleAlloc] = allocate(bScaleChain);

   // Share the same set of barriers for both.
   Value emptyBars = createBarrierAlloc(loop, numStages);
@@ -304,9 +328,23 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
   int loadSizeInBytes =
       product(aType.getShape()) * aType.getElementTypeBitWidth() / 8 +
       product(bType.getShape()) * bType.getElementTypeBitWidth() / 8;
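+  // The scale loads arrive through the same barrier, so count their bytes
+  // in the expected transaction size too.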
+  if (aScaleLoad)
+    loadSizeInBytes += product(aScaleType.getShape()) *
+                       aScaleType.getElementTypeBitWidth() / 8;
+  if (bScaleLoad)
+    loadSizeInBytes += product(bScaleType.getShape()) *
+                       bScaleType.getElementTypeBitWidth() / 8;

   // Insert before the group of loads.
-  b.setInsertionPoint(aLoad->isBeforeInBlock(bLoad) ? aLoad : bLoad);
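+  // Collect whichever loads are present and insert before the earliest one in
+  // program order, so the barrier wait precedes every async copy.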
+  SmallVector<Operation *> allLoads{aLoad, bLoad};
+  if (aScaleLoad)
+    allLoads.push_back(aScaleLoad);
+  if (bScaleLoad)
+    allLoads.push_back(bScaleLoad);
+  std::sort(allLoads.begin(), allLoads.end(),
+            [](Operation *a, Operation *b) { return a->isBeforeInBlock(b); });
+  b.setInsertionPoint(allLoads.front());
+
   // Wait for the buffer to be empty and the corresponding barrier to be
   // exhausted.
   Value curEmptyBar = createSingleBufferView(b, emptyBars, loadIndex);
@@ -318,19 +356,21 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
                                    loadSizeInBytes, intCst(true, 1));

   // Replace the loads with async copies.
-  b.setInsertionPoint(aLoad);
-  Value aView = createSingleBufferView(b, aAlloc, loadIndex);
-  lowerTMACopy(b, *loadPartition, aLoad, curLoadBar, aView);
-  replaceUsesAndPropagateType(b, *aLoad->user_begin(), aView);
-  aLoad->user_begin()->erase();
-  aLoad->erase();
-
-  b.setInsertionPoint(bLoad);
-  Value bView = createSingleBufferView(b, bAlloc, loadIndex);
-  lowerTMACopy(b, *loadPartition, bLoad, curLoadBar, bView);
-  replaceUsesAndPropagateType(b, *bLoad->user_begin(), bView);
-  bLoad->user_begin()->erase();
-  bLoad->erase();
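+  // Lower a load to an async TMA copy into its per-stage buffer view and
+  // forward the shared-memory view to the load's single user.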
+  auto lowerLoadAndPropagate = [&](Operation *load, Value alloc,
+                                   Value barrier) {
+    b.setInsertionPoint(load);
+    Value view = createSingleBufferView(b, alloc, loadIndex);
+    lowerTMACopy(b, *loadPartition, load, barrier, view);
+    replaceUsesAndPropagateType(b, *load->user_begin(), view);
+    load->user_begin()->erase();
+    load->erase();
+  };
+  lowerLoadAndPropagate(aLoad, aAlloc, curLoadBar);
+  lowerLoadAndPropagate(bLoad, bAlloc, curLoadBar);
+  if (aScaleLoad)
+    lowerLoadAndPropagate(aScaleLoad, aScaleAlloc, curLoadBar);
+  if (bScaleLoad)
+    lowerLoadAndPropagate(bScaleLoad, bScaleAlloc, curLoadBar);

   // Place the remaining users in the MMA partition. Re-acquire the use chain
   // because some ops were invalidated by `replaceUsesAndPropagateType`.
@@ -339,9 +379,18 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
   aChain.push_back(mmaOp);
   (void)findSingleChainToLoad(loop, dot.getA(), aChain);
   (void)findSingleChainToLoad(loop, dot.getB(), bChain);
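+  // The scale chains were invalidated as well; rebuild them before assigning
+  // their ops to the MMA partition.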
+  if (aScaleLoad) {
+    aScaleChain.clear();
+    (void)findSingleChainToLoad(loop, scaledMMAOp.getAScale(), aScaleChain);
+  }
+  if (bScaleLoad) {
+    bScaleChain.clear();
+    (void)findSingleChainToLoad(loop, scaledMMAOp.getBScale(), bScaleChain);
+  }

   // Place users in the MMA partition.
-  auto allUsers = llvm::to_vector(llvm::concat<Operation *>(aChain, bChain));
+  auto allUsers = llvm::to_vector(
+      llvm::concat<Operation *>(aChain, bChain, aScaleChain, bScaleChain));
   for (Operation *user : allUsers)
     mmaPartition->insert(user);
