Skip to content

Commit 2d6fb76

Browse files
authored
[Backend] Better warp specialization register reallocation (#6877)
This is the last set of changes split from #6760 This moves where the registers are reallocated into better spots. It also causes the worker partitions to immediately give up a bunch of registers when they aren't active to the default warp group, so that it has more registers to execute the "synchronous" parts of the code. This is useful when there are many worker warps and the default warp group does not get many registers at the start of the kernel (maxnreg is low).
1 parent 0f5eccc commit 2d6fb76

File tree

5 files changed

+150
-43
lines changed

5 files changed

+150
-43
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
257257
llvm::zip(wsOp.getPartitionRegions(), partitionNumWarps,
258258
wsOp.getPartitionNumWarps(), maxTensorRegs, estRegUsage)) {
259259
// "Guess" the register usage for each partition.
260-
estRegs = tensorRegs ? 72 : 24;
260+
estRegs = tensorRegs ? 88 : 24;
261261

262262
// Layouts need to be reassigned if the number of warps changed and there
263263
// are tensor computations.

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,15 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
208208
Operation *op = operandViews.pop_back_val();
209209
if (!op->hasOneUse() || !op->hasTrait<OpTrait::MemDescViewTrait>())
210210
continue;
211+
212+
// Duplicate the op if necessary to ensure the MMA op is the only user.
213+
if (!llvm::all_of(op->getUsers(),
214+
[&](Operation *user) { return user == mmaOp; })) {
215+
Operation *viewOp = OpBuilder(op).clone(*op);
216+
mmaOp->replaceUsesOfWith(op->getResult(0), viewOp->getResult(0));
217+
op = viewOp;
218+
}
219+
211220
schedule.trySchedule(mmaPartition, op);
212221
if (Operation *defOp = op->getOperand(0).getDefiningOp())
213222
operandViews.push_back(defOp);

test/Conversion/warp_specialize_to_llvm.mlir

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -723,41 +723,121 @@ llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 :
723723

724724
// CHECK-LABEL: @dynamic_register_reallocation
725725
llvm.func @dynamic_register_reallocation() attributes {allocation.offset = 0 : i32} {
726+
// CHECK: cond_br %{{.*}}, [[ENTRY:\^.*]], [[SWITCH_LOOP:\^.*]]
727+
728+
// CHECK: [[SWITCH_LOOP]]:
729+
// CHECK-NEXT: nvvm.setmaxregister decrease 24
730+
// CHECK-NEXT: barrier.sync 1 ;
726731
// CHECK: llvm.switch
727732
// CHECK-NEXT: 0: [[PARTITION0:\^.*]],
728733
// CHECK-NEXT: 1: [[PARTITION1:\^.*]],
729734
// CHECK-NEXT: 2: [[PARTITION2:\^.*]],
730735
// CHECK-NEXT: 3: [[EXIT:\^.*]]
731736

732737
// CHECK: [[PARTITION0]]:
733-
// CHECK-NEXT: barrier.sync 1 ;
734738
// CHECK-NEXT: nvvm.setmaxregister increase 80
739+
// CHECK-NEXT: barrier.sync 1 ;
735740
// CHECK-NEXT: "partition0"()
736741
// CHECK-NEXT: barrier.sync 1 ;
737-
// CHECK-NEXT: nvvm.setmaxregister increase 80
742+
// CHECK-NEXT: nvvm.setmaxregister decrease 24
738743

739744
// CHECK: [[PARTITION1]]:
745+
// CHECK-NEXT: nvvm.setmaxregister increase 48
740746
// CHECK-NEXT: barrier.sync 1 ;
741-
// CHECK-NEXT: nvvm.setmaxregister decrease 48
742747
// CHECK-NEXT: "partition1"()
743748
// CHECK-NEXT: barrier.sync 1 ;
749+
// CHECK-NEXT: nvvm.setmaxregister decrease 24
750+
751+
// CHECK: [[PARTITION2]]:
752+
// CHECK-NEXT: nvvm.setmaxregister increase 128
753+
// CHECK-NEXT: barrier.sync 1 ;
754+
// CHECK-NEXT: "partition2"()
755+
// CHECK-NEXT: barrier.sync 1 ;
756+
// CHECK-NEXT: nvvm.setmaxregister decrease 24
757+
758+
// CHECK: [[ENTRY]]:
759+
// CHECK-NEXT: nvvm.setmaxregister increase 248
760+
761+
// CHECK: barrier.sync 1 ;
762+
// CHECK-NEXT: setmaxregister decrease 152
763+
// CHECK-NEXT: barrier.sync 1 ;
764+
// CHECK: "default"
765+
// CHECK: barrier.sync 1 ;
766+
// CHECK-NEXT: setmaxregister increase 248
767+
768+
ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 4, 8, 12>, actualRegisters = array<i32: 152, 80, 48, 128>}
769+
default {
770+
"default"() : () -> ()
771+
ttg.warp_yield
772+
}
773+
partition0() num_warps(4) {
774+
"partition0"() : () -> ()
775+
ttg.warp_return
776+
}
777+
partition1() num_warps(4) {
778+
"partition1"() : () -> ()
779+
ttg.warp_return
780+
}
781+
partition2() num_warps(4) {
782+
"partition2"() : () -> ()
783+
ttg.warp_return
784+
} : () -> ()
785+
llvm.return
786+
}
787+
788+
}
789+
790+
// -----
791+
792+
module attributes {ttg.maxnreg = 128 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.total-num-warps" = 16 : i32} {
793+
794+
llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
795+
796+
// CHECK-LABEL: @dynamic_register_reallocation
797+
llvm.func @dynamic_register_reallocation_overalloc() attributes {allocation.offset = 0 : i32} {
798+
// CHECK: cond_br %{{.*}}, [[ENTRY:\^.*]], [[SWITCH_LOOP:\^.*]]
799+
800+
// CHECK: [[SWITCH_LOOP]]:
801+
// CHECK-NEXT: nvvm.setmaxregister decrease 80
802+
// CHECK-NEXT: barrier.sync 1 ;
803+
// CHECK: llvm.switch
804+
// CHECK-NEXT: 0: [[PARTITION0:\^.*]],
805+
// CHECK-NEXT: 1: [[PARTITION1:\^.*]],
806+
// CHECK-NEXT: 2: [[PARTITION2:\^.*]],
807+
// CHECK-NEXT: 3: [[EXIT:\^.*]]
808+
809+
// CHECK: [[PARTITION0]]:
810+
// CHECK-NEXT: nvvm.setmaxregister decrease 24
811+
// CHECK-NEXT: barrier.sync 1 ;
812+
// CHECK-NEXT: "partition0"()
813+
// CHECK-NEXT: barrier.sync 1 ;
744814
// CHECK-NEXT: nvvm.setmaxregister increase 80
745815

816+
// CHECK: [[PARTITION1]]:
817+
// CHECK-NEXT: nvvm.setmaxregister increase 192
818+
// CHECK-NEXT: barrier.sync 1 ;
819+
// CHECK-NEXT: "partition1"()
820+
// CHECK-NEXT: barrier.sync 1 ;
821+
// CHECK-NEXT: nvvm.setmaxregister decrease 80
822+
746823
// CHECK: [[PARTITION2]]:
824+
// CHECK-NEXT: nvvm.setmaxregister increase 192
747825
// CHECK-NEXT: barrier.sync 1 ;
748-
// CHECK-NEXT: nvvm.setmaxregister increase 128
749826
// CHECK-NEXT: "partition2"()
750827
// CHECK-NEXT: barrier.sync 1 ;
751828
// CHECK-NEXT: nvvm.setmaxregister decrease 80
752829

830+
// CHECK: [[ENTRY]]:
831+
// CHECK-NEXT: nvvm.setmaxregister increase 256
832+
753833
// CHECK: barrier.sync 1 ;
834+
// CHECK-NEXT: setmaxregister decrease 104
754835
// CHECK-NEXT: barrier.sync 1 ;
755-
// CHECK: setmaxregister increase 152
756836
// CHECK: "default"
757837
// CHECK: barrier.sync 1 ;
758-
// CHECK-NEXT: setmaxregister decrease 80
838+
// CHECK-NEXT: setmaxregister increase 256
759839

760-
ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 4, 8, 12>, actualRegisters = array<i32: 152, 80, 48, 128>}
840+
ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 4, 8, 12>, actualRegisters = array<i32: 104, 24, 192, 192>}
761841
default {
762842
"default"() : () -> ()
763843
ttg.warp_yield

test/TritonGPU/optimize-partition-warps.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ tt.func @fits_after_shrink(%arg0: i32) {
119119

120120
// CHECK-LABEL: @register_use_heuristic
121121
tt.func @register_use_heuristic() {
122-
// CHECK: requestedRegisters = array<i32: 24, 72>
122+
// CHECK: requestedRegisters = array<i32: 24, 88>
123123
ttg.warp_specialize()
124124
default {
125125
ttg.warp_yield

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -187,20 +187,6 @@ static void createRegRealloc(TritonLLVMIRRewriter &b, int curRegs,
187187
b.create<NVVM::SetMaxRegisterOp>(adjRegs, action);
188188
}
189189

190-
static void createEntryRegRealloc(TritonLLVMIRRewriter &b, Operation *op,
191-
int actRegs) {
192-
auto maxnreg = op->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
193-
AttrMaxRegistersName);
194-
createRegRealloc(b, maxnreg.getInt(), actRegs);
195-
}
196-
197-
static void createExitRegRealloc(TritonLLVMIRRewriter &b, Operation *op,
198-
int actRegs) {
199-
auto maxnreg = op->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
200-
AttrMaxRegistersName);
201-
createRegRealloc(b, actRegs, maxnreg.getInt());
202-
}
203-
204190
// Assign hardware barriers to each warp group and rewrite warp group barriers
205191
// into `barrier.sync` instructions. There is a maximum number of barriers.
206192
static LogicalResult rewriteWarpGroupBarriers(LLVM::LLVMFuncOp func,
@@ -245,13 +231,20 @@ static LogicalResult rewriteWarpGroupBarriers(LLVM::LLVMFuncOp func,
245231
}
246232

247233
static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
248-
const NVIDIA::TargetInfo &targetInfo) {
234+
const NVIDIA::TargetInfo &targetInfo,
235+
int lowRegs) {
249236
TritonLLVMIRRewriter b(ws.getLoc(), ws.getContext());
250237

251238
for (Region *partition : ws.getPartitionRegions()) {
252239
// Load the explicit captures from shared memory and replace the block args
253240
// if there are any.
254241
b.setInsertionPointToStart(&partition->front());
242+
243+
if (auto actRegs = ws.getActualRegisters()) {
244+
createRegRealloc(b, lowRegs,
245+
(*actRegs)[partition->getRegionNumber() + 1]);
246+
}
247+
255248
if (partition->getNumArguments()) {
256249
auto captureType = LLVM::LLVMStructType::getLiteral(
257250
b.getContext(), llvm::to_vector(partition->getArgumentTypes()),
@@ -275,19 +268,15 @@ static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
275268
// another barrier here.
276269
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
277270
/*aligned=*/false);
278-
if (auto actRegs = ws.getActualRegisters()) {
279-
createEntryRegRealloc(b, ws,
280-
(*actRegs)[partition->getRegionNumber() + 1]);
281-
}
282271

283272
// Rewrite all warp returns.
284273
partition->walk([&](WarpReturnOp op) {
285274
TritonLLVMIRRewriter b(op.getLoc(), op);
286275
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
287276
/*aligned=*/false);
288277
if (auto actRegs = ws.getActualRegisters()) {
289-
createExitRegRealloc(b, ws,
290-
(*actRegs)[partition->getRegionNumber() + 1]);
278+
createRegRealloc(b, (*actRegs)[partition->getRegionNumber() + 1],
279+
lowRegs);
291280
}
292281
b.replaceOpWithNewOp<LLVM::BrOp>(op, switchLoop);
293282
});
@@ -328,6 +317,39 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
328317
defaultWarpGroupSize)))
329318
return failure();
330319

320+
auto totalNumWarpsAttr =
321+
module->getAttrOfType<IntegerAttr>("ttg.total-num-warps");
322+
if (!totalNumWarpsAttr) {
323+
return mlir::emitError(module.getLoc(),
324+
"module missing 'ttg.total-num-warps' attribute");
325+
}
326+
unsigned totalNumThreads = totalNumWarpsAttr.getInt() * threadsPerWarp;
327+
328+
// Determine how many registers the worker warps can surrender before they
329+
// begin execution.
330+
auto maxnreg = func->getParentOfType<ModuleOp>()->getAttrOfType<IntegerAttr>(
331+
AttrMaxRegistersName);
332+
int lowRegs = -1;
333+
int defRegs = -1;
334+
if (maxnreg) {
335+
int numWorkerWarps = totalNumWarpsAttr.getInt() - defaultNumWarps;
336+
int startRegs = maxnreg.getInt();
337+
338+
// First determine how many extra registers the default warp group can get
339+
// if the workers surrender the maximum number of registers.
340+
lowRegs = 24;
341+
int extraRegs = (startRegs - lowRegs) * numWorkerWarps / defaultNumWarps;
342+
defRegs = (startRegs + extraRegs) / 8 * 8;
343+
344+
// If the default warp group goes over 256 registers, the workers don't need
345+
// to give up this much.
346+
if (defRegs > 256) {
347+
defRegs = 256;
348+
int giveRegs = (defRegs - startRegs) * defaultNumWarps / numWorkerWarps;
349+
lowRegs = (startRegs - giveRegs) / 8 * 8;
350+
}
351+
}
352+
331353
// Attempt to elide captures of trivial computations by hoisting them into the
332354
// header or rematerializing them into each partition.
333355
elideTrivialCaptures(func, wsOps);
@@ -357,22 +379,18 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
357379
llvm::zip(header->getArguments(), entry->getArguments()))
358380
oldArg.replaceAllUsesWith(arg);
359381
entry->eraseArguments([](auto) { return true; });
360-
361-
// Generate the switch loop.
362-
auto totalNumWarpsAttr =
363-
module->getAttrOfType<IntegerAttr>("ttg.total-num-warps");
364-
if (!totalNumWarpsAttr) {
365-
return mlir::emitError(module.getLoc(),
366-
"module missing 'ttg.total-num-warps' attribute");
367-
}
368-
unsigned totalNumThreads = totalNumWarpsAttr.getInt() * threadsPerWarp;
382+
b.setInsertionPointToStart(entry);
383+
if (maxnreg)
384+
createRegRealloc(b, maxnreg.getInt(), defRegs);
369385

370386
// ^switchLoop:
371387
// barrier.sync 1
372388
// %state_ptr = getelementptr (ptr @shared), <offset>
373389
// %rel_tid = sub %tid, <default_warp_group_size>
374390
// %rel_wid = udiv %rel_tid, 32
375391
b.setInsertionPointToStart(switchLoop);
392+
if (maxnreg)
393+
createRegRealloc(b, maxnreg.getInt(), lowRegs);
376394
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
377395
/*aligned=*/false);
378396
Value statePtr = LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func);
@@ -400,7 +418,7 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
400418
SmallVector<SmallVector<int32_t>> warpToState(
401419
wsOps.size(), SmallVector<int32_t>(maxNumWarps, -1));
402420
for (auto [op, stateMap] : llvm::zip(wsOps, warpToState)) {
403-
rewritePartitionRegions(op, switchLoop, targetInfo);
421+
rewritePartitionRegions(op, switchLoop, targetInfo, lowRegs);
404422
for (auto [partition, partitionNumWarps, startId] :
405423
llvm::zip(op.getPartitionRegions(), op.getPartitionNumWarps(),
406424
*op.getWarpGroupStartIds())) {
@@ -480,18 +498,18 @@ static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
480498
// they have read the captures before the memory is released upon entry.
481499
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
482500
/*aligned=*/false);
501+
if (auto actRegs = ws.getActualRegisters())
502+
createRegRealloc(b, defRegs, actRegs->front());
483503
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
484504
/*aligned=*/false);
485-
if (auto actRegs = ws.getActualRegisters())
486-
createEntryRegRealloc(b, func, actRegs->front());
487505
b.create<LLVM::BrOp>(&ws.getDefaultRegion().front());
488506

489507
ws.getDefaultRegion().walk([&, ws = ws](WarpYieldOp op) mutable {
490508
TritonLLVMIRRewriter b(op.getLoc(), op);
491509
createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
492510
/*aligned=*/false);
493511
if (auto actRegs = ws.getActualRegisters())
494-
createExitRegRealloc(b, func, actRegs->front());
512+
createRegRealloc(b, actRegs->front(), defRegs);
495513
b.replaceOpWithNewOp<LLVM::BrOp>(op, op.getOperands(), after);
496514
});
497515
after->getParent()->getBlocks().splice(after->getIterator(),

0 commit comments

Comments
 (0)