@@ -187,20 +187,6 @@ static void createRegRealloc(TritonLLVMIRRewriter &b, int curRegs,
187187 b.create <NVVM::SetMaxRegisterOp>(adjRegs, action);
188188}
189189
190- static void createEntryRegRealloc (TritonLLVMIRRewriter &b, Operation *op,
191- int actRegs) {
192- auto maxnreg = op->getParentOfType <ModuleOp>()->getAttrOfType <IntegerAttr>(
193- AttrMaxRegistersName);
194- createRegRealloc (b, maxnreg.getInt (), actRegs);
195- }
196-
197- static void createExitRegRealloc (TritonLLVMIRRewriter &b, Operation *op,
198- int actRegs) {
199- auto maxnreg = op->getParentOfType <ModuleOp>()->getAttrOfType <IntegerAttr>(
200- AttrMaxRegistersName);
201- createRegRealloc (b, actRegs, maxnreg.getInt ());
202- }
203-
204190// Assign hardware barriers to each warp group and rewrite warp group barriers
205191// into `barrier.sync` instructions. There is a maximum number of barriers.
206192static LogicalResult rewriteWarpGroupBarriers (LLVM::LLVMFuncOp func,
@@ -245,13 +231,20 @@ static LogicalResult rewriteWarpGroupBarriers(LLVM::LLVMFuncOp func,
245231}
246232
247233static void rewritePartitionRegions (WarpSpecializeOp ws, Block *switchLoop,
248- const NVIDIA::TargetInfo &targetInfo) {
234+ const NVIDIA::TargetInfo &targetInfo,
235+ int lowRegs) {
249236 TritonLLVMIRRewriter b (ws.getLoc (), ws.getContext ());
250237
251238 for (Region *partition : ws.getPartitionRegions ()) {
252239 // Load the explicit captures from shared memory and replace the block args
253240 // if there are any.
254241 b.setInsertionPointToStart (&partition->front ());
242+
243+ if (auto actRegs = ws.getActualRegisters ()) {
244+ createRegRealloc (b, lowRegs,
245+ (*actRegs)[partition->getRegionNumber () + 1 ]);
246+ }
247+
255248 if (partition->getNumArguments ()) {
256249 auto captureType = LLVM::LLVMStructType::getLiteral (
257250 b.getContext (), llvm::to_vector (partition->getArgumentTypes ()),
@@ -275,19 +268,15 @@ static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
275268 // another barrier here.
276269 createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
277270 /* aligned=*/ false );
278- if (auto actRegs = ws.getActualRegisters ()) {
279- createEntryRegRealloc (b, ws,
280- (*actRegs)[partition->getRegionNumber () + 1 ]);
281- }
282271
283272 // Rewrite all warp returns.
284273 partition->walk ([&](WarpReturnOp op) {
285274 TritonLLVMIRRewriter b (op.getLoc (), op);
286275 createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
287276 /* aligned=*/ false );
288277 if (auto actRegs = ws.getActualRegisters ()) {
289- createExitRegRealloc (b, ws ,
290- (*actRegs)[partition-> getRegionNumber () + 1 ] );
278+ createRegRealloc (b, (*actRegs)[partition-> getRegionNumber () + 1 ] ,
279+ lowRegs );
291280 }
292281 b.replaceOpWithNewOp <LLVM::BrOp>(op, switchLoop);
293282 });
@@ -328,6 +317,39 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
328317 defaultWarpGroupSize)))
329318 return failure ();
330319
320+ auto totalNumWarpsAttr =
321+ module ->getAttrOfType <IntegerAttr>(" ttg.total-num-warps" );
322+ if (!totalNumWarpsAttr) {
323+ return mlir::emitError (module .getLoc (),
324+ " module missing 'ttg.total-num-warps' attribute" );
325+ }
326+ unsigned totalNumThreads = totalNumWarpsAttr.getInt () * threadsPerWarp;
327+
328+ // Determine how many registers the worker warps can surrender before they
329+ // begin execution.
330+ auto maxnreg = func->getParentOfType <ModuleOp>()->getAttrOfType <IntegerAttr>(
331+ AttrMaxRegistersName);
332+ int lowRegs = -1 ;
333+ int defRegs = -1 ;
334+ if (maxnreg) {
335+ int numWorkerWarps = totalNumWarpsAttr.getInt () - defaultNumWarps;
336+ int startRegs = maxnreg.getInt ();
337+
338+ // First determine how many extra registers the default warp group can get
339+ // if the workers surrender the maximum number of registers.
340+ lowRegs = 24 ;
341+ int extraRegs = (startRegs - lowRegs) * numWorkerWarps / defaultNumWarps;
342+ defRegs = (startRegs + extraRegs) / 8 * 8 ;
343+
344+ // If the default warp group goes over 256 registers, the workers don't need
345+ // to give up this much.
346+ if (defRegs > 256 ) {
347+ defRegs = 256 ;
348+ int giveRegs = (defRegs - startRegs) * defaultNumWarps / numWorkerWarps;
349+ lowRegs = (startRegs - giveRegs) / 8 * 8 ;
350+ }
351+ }
352+
331353 // Attempt to elide captures of trivial computations by hoisting them into the
332354 // header or rematerializing them into each partition.
333355 elideTrivialCaptures (func, wsOps);
@@ -357,22 +379,18 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
357379 llvm::zip (header->getArguments (), entry->getArguments ()))
358380 oldArg.replaceAllUsesWith (arg);
359381 entry->eraseArguments ([](auto ) { return true ; });
360-
361- // Generate the switch loop.
362- auto totalNumWarpsAttr =
363- module ->getAttrOfType <IntegerAttr>(" ttg.total-num-warps" );
364- if (!totalNumWarpsAttr) {
365- return mlir::emitError (module .getLoc (),
366- " module missing 'ttg.total-num-warps' attribute" );
367- }
368- unsigned totalNumThreads = totalNumWarpsAttr.getInt () * threadsPerWarp;
382+ b.setInsertionPointToStart (entry);
383+ if (maxnreg)
384+ createRegRealloc (b, maxnreg.getInt (), defRegs);
369385
370386 // ^switchLoop:
371387 // barrier.sync 1
372388 // %state_ptr = getelementptr (ptr @shared), <offset>
373389 // %rel_tid = sub %tid, <default_warp_group_size>
374390 // %rel_wid = udiv %rel_tid, 32
375391 b.setInsertionPointToStart (switchLoop);
392+ if (maxnreg)
393+ createRegRealloc (b, maxnreg.getInt (), lowRegs);
376394 createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
377395 /* aligned=*/ false );
378396 Value statePtr = LLVM::getSharedMemoryBase (b.getLoc (), b, targetInfo, func);
@@ -400,7 +418,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
400418 SmallVector<SmallVector<int32_t >> warpToState (
401419 wsOps.size (), SmallVector<int32_t >(maxNumWarps, -1 ));
402420 for (auto [op, stateMap] : llvm::zip (wsOps, warpToState)) {
403- rewritePartitionRegions (op, switchLoop, targetInfo);
421+ rewritePartitionRegions (op, switchLoop, targetInfo, lowRegs );
404422 for (auto [partition, partitionNumWarps, startId] :
405423 llvm::zip (op.getPartitionRegions (), op.getPartitionNumWarps (),
406424 *op.getWarpGroupStartIds ())) {
@@ -480,18 +498,18 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
480498 // they have read the captures before the memory is released upon entry.
481499 createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
482500 /* aligned=*/ false );
501+ if (auto actRegs = ws.getActualRegisters ())
502+ createRegRealloc (b, defRegs, actRegs->front ());
483503 createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
484504 /* aligned=*/ false );
485- if (auto actRegs = ws.getActualRegisters ())
486- createEntryRegRealloc (b, func, actRegs->front ());
487505 b.create <LLVM::BrOp>(&ws.getDefaultRegion ().front ());
488506
489507 ws.getDefaultRegion ().walk ([&, ws = ws](WarpYieldOp op) mutable {
490508 TritonLLVMIRRewriter b (op.getLoc (), op);
491509 createBarrier (b, kSwitchLoopBarrierIdx , /* numThreads=*/ std::nullopt ,
492510 /* aligned=*/ false );
493511 if (auto actRegs = ws.getActualRegisters ())
494- createExitRegRealloc (b, func, actRegs->front ());
512+ createRegRealloc (b, actRegs->front (), defRegs );
495513 b.replaceOpWithNewOp <LLVM::BrOp>(op, op.getOperands (), after);
496514 });
497515 after->getParent ()->getBlocks ().splice (after->getIterator (),
0 commit comments