@@ -100,8 +100,8 @@ TMemMessageTraits getTMemMessageFromAtom(const TMemAccessAtom &atom,
100100// Only allows half of the thread registers to be used for tensor memory access
101101// to avoid register pressure. This ensures the largest tmem message width is
102102// used for the workload without inducing spills.
103- int getTMemMessageNarrowingFactor (int workloadThreadRegs) {
104- const int allowedRegUsage = maxRegisters / 2 ;
103+ int getTMemMessageNarrowingFactor (int workloadThreadRegs, int maxnreg ) {
104+ const int allowedRegUsage = maxnreg / 2 ;
105105 int narrowingFactor = 1 ;
106106 while (workloadThreadRegs > allowedRegUsage) {
107107 workloadThreadRegs /= 2 ;
@@ -338,13 +338,13 @@ void createWaitOpSt(Location loc, ConversionPatternRewriter &rewriter) {
338338 ptxBuilder.launch (rewriter, loc, void_ty (rewriter.getContext ()));
339339}
340340
341- TMemMessageTraits selectTMemMessage (const TMemRuntimeInfo &info) {
341+ TMemMessageTraits selectTMemMessage (const TMemRuntimeInfo &info, int maxnreg ) {
342342 auto atom = info.useStridedMessage ? TMemAccess16x32bx2 : TMemAccess32x32b;
343343
344344 int totalRegsNeeded =
345345 getEffectiveRegs (info.unpackedb16 , info.useStridedMessage ,
346346 info.numCols / info.numWarpGroups );
347- int narrowingFactor = getTMemMessageNarrowingFactor (totalRegsNeeded);
347+ int narrowingFactor = getTMemMessageNarrowingFactor (totalRegsNeeded, maxnreg );
348348 auto narrowedMessage = getTMemMessageFromAtom (atom, narrowingFactor);
349349 narrowedMessage = constrainMessageFromWorkload (narrowedMessage, info,
350350 narrowedMessage.numRegs );
@@ -355,6 +355,35 @@ TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info) {
355355 return std::min (narrowedMessage, maxWidthMessage);
356356}
357357
358+ // Get the maximum number of registers per thread based on the context. This is
359+ // by default 256, but it can be overridden by `ttg.maxnreg` set on the module.
360+ // Alternatively, warp groups within warp specialized regions can have a
361+ // different number of registers allocated.
362+ static int getContextualMaxNReg (Operation *op) {
363+ if (auto mod = dyn_cast<ModuleOp>(op)) {
364+ // Check for a maxnreg attribute.
365+ if (auto attr = op->getAttrOfType <IntegerAttr>(AttrMaxRegistersName))
366+ return std::max<int >(maxRegisters, attr.getInt ());
367+
368+ } else if (auto partitions =
369+ dyn_cast<WarpSpecializePartitionsOp>(op->getParentOp ())) {
370+ // Check if the partition has reduced registers.
371+ unsigned idx = op->getParentRegion ()->getRegionNumber ();
372+ if (auto actRegisters = partitions.getParentOp ().getActualRegisters ())
373+ return std::max<int >(maxRegisters, (*actRegisters)[1 + idx]);
374+ return getContextualMaxNReg (partitions.getParentOp ());
375+
376+ } else if (auto wsOp = dyn_cast<WarpSpecializeOp>(op->getParentOp ())) {
377+ // Check the register usage of the default warpgroup.
378+ if (auto actRegisters = wsOp.getActualRegisters ())
379+ return std::max<int >(maxRegisters, actRegisters->front ());
380+ }
381+
382+ if (Operation *parent = op->getParentOp ())
383+ return getContextualMaxNReg (parent);
384+ return maxRegisters;
385+ }
386+
358387static void lowerStoreToTensorMemory (Location loc, Operation *op, Value src,
359388 Value dest, Value llSrc, Value pred,
360389 Value tmemBase,
@@ -365,7 +394,8 @@ static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src,
365394 auto dstType = cast<MemDescType>(dest.getType ());
366395 auto info = getTMemRuntimeInfo (op, cast<RankedTensorType>(src.getType ()),
367396 cast<MemDescType>(dest.getType ()));
368- const TMemMessageTraits message = selectTMemMessage (info);
397+ const TMemMessageTraits message =
398+ selectTMemMessage (info, getContextualMaxNReg (op));
369399 int regIdx = 0 ;
370400 calculateAddressAndEmitTmemMessage (
371401 loc, tmemBase, info, message, rewriter,
@@ -503,7 +533,8 @@ struct TensorMemoryLoadOpConversion
503533
504534 auto info = getTMemRuntimeInfo (op, cast<RankedTensorType>(op.getType ()),
505535 cast<MemDescType>(op.getSrc ().getType ()));
506- const TMemMessageTraits message = selectTMemMessage (info);
536+ const TMemMessageTraits message =
537+ selectTMemMessage (info, getContextualMaxNReg (op));
507538 SmallVector<Value> resultVals;
508539 calculateAddressAndEmitTmemMessage (
509540 loc, tmemBase, info, message, rewriter,
0 commit comments