
Commit 711caa4

[TritonGPU] Control dynamic register allocation from triton (#6407)
This pipes through the compiler the ability to set dynamic register usage in warp specialization. The middle-end estimates how many registers each partition will use as part of deciding how many warps each partition gets. The warpgroup allocator then groups partitions into warpgroups and works out the actual distribution of registers. If registers can be redistributed, it records a final register count per warpgroup, which in turn generates `nvvm.setmaxregister` directives. Estimating register usage from TTGIR is in general not possible, so for now the estimates are hard-coded.
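To make the redistribution concrete, below is a minimal standalone sketch of the allocator's arithmetic (an editorial illustration mirroring the AllocateWarpGroups.cpp change in this commit, not code from it), using the numbers from the new allocate_warp_groups.mlir test: 8 default warps plus partitions of 1 + 2 + 1 warps requesting 48/80/48 registers, at 32 threads per warp.

#include <cstdio>

int main() {
  // Per-thread budget: 64K registers per SM divided across all 12 warps,
  // snapped down to a multiple of 8 (the setmaxnreg granularity).
  int maxnreg = (1 << 16) / (8 + 4) / 32 / 8 * 8; // 170 -> 168
  // The partitions form one 4-warp warpgroup requesting max(48, 80, 48) = 80.
  int deficit = (maxnreg - 80) * 4 * 32; // 11264 registers freed
  // The freed registers are handed to the 8-warp default warpgroup.
  int leftover = (8 * 32 * maxnreg + deficit) / (8 * 32) / 8 * 8; // 212 -> 208
  std::printf("maxnreg=%d, actualRegisters=[%d, 80, 80, 80]\n", maxnreg, leftover);
}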
1 parent 366de71 commit 711caa4

File tree: 10 files changed, +203 −22 lines

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ template <> struct hash<CacheKey> {
 
 namespace mlir::triton::gpu {
 
+constexpr static char AttrMaxRegistersName[] = "ttg.maxnreg";
 constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
 constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
 constexpr static char AttrTargetName[] = "ttg.target";

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 3 additions & 1 deletion
@@ -376,7 +376,9 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
   let arguments = (ins
     Variadic<AnyType>:$explicitCaptures,
     DenseI32ArrayAttr:$partitionNumWarps,
-    OptionalAttr<DenseI32ArrayAttr>:$warpGroupStartIds
+    OptionalAttr<DenseI32ArrayAttr>:$warpGroupStartIds,
+    OptionalAttr<DenseI32ArrayAttr>:$requestedRegisters,
+    OptionalAttr<DenseI32ArrayAttr>:$actualRegisters
   );
   let results = (outs Variadic<AnyType>:$defaultPassthrough);

lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp

Lines changed: 89 additions & 0 deletions
@@ -18,6 +18,20 @@ struct AllocateWarpGroups
   void runOnOperation() override {
     ModuleOp mod = getOperation();
 
+    int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
+
+    struct WarpGroupInfo {
+      SmallVector<Region *> partitions;
+      int maxRequestedRegs = 0;
+      unsigned numWarps = 0;
+    };
+    struct WarpGroupPartition {
+      int startId;
+      Region *partition;
+      int32_t estRegs;
+      int numWarps;
+    };
+
     // Compute the total number of warps required at any given time.
     int baseNumWarps = lookupNumWarps(mod);
     int maxExtraWarps = 0;
@@ -42,6 +56,81 @@ struct AllocateWarpGroups
         startId += size;
      }
      op.setWarpGroupStartIds(startIds);
+
+      // Require that an estimate has been set and that we have even warpgroups.
+      auto regsAttr = op.getRequestedRegisters();
+      if (!regsAttr || op.getTotalPartitionWarps() % 4 != 0)
+        return;
+
+      // Group the partitions into warpgroups.
+      SmallVector<WarpGroupPartition> orderedPartitions;
+      for (auto [startId, partition, estRegs, numWarps] :
+           llvm::zip(startIds, op.getPartitionRegions(), *regsAttr, arr))
+        orderedPartitions.push_back({startId, partition, estRegs, numWarps});
+      llvm::sort(orderedPartitions,
+                 [&](auto lhs, auto rhs) { return lhs.startId < rhs.startId; });
+
+      // Iterate over the partitions and assign them to warp groups. Determine
+      // the maximum number of requested registers per warp group.
+      SmallVector<WarpGroupInfo> warpGroups;
+      for (auto [startId, partition, estRegs, numWarps] : orderedPartitions) {
+        if (startId % 4 == 0) {
+          warpGroups.push_back(WarpGroupInfo{});
+        }
+        warpGroups.back().partitions.push_back(partition);
+        // Round up to the nearest multiple of 8.
+        int estRegsCeil8 = llvm::divideCeil(estRegs, 8) * 8;
+        warpGroups.back().maxRequestedRegs =
+            std::max<int>(warpGroups.back().maxRequestedRegs, estRegsCeil8);
+        warpGroups.back().numWarps += numWarps;
+      }
+
+      // Determine the maximum number of registers per thread. This may have
+      // been set by the user.
+      int maxnreg;
+      if (auto maxnregAttr =
+              op->getAttrOfType<IntegerAttr>(AttrMaxRegistersName)) {
+        maxnreg = maxnregAttr.getInt();
+      } else {
+        maxnreg = (1 << 16) / (baseNumWarps + op.getTotalPartitionWarps()) /
+                  threadsPerWarp;
+        maxnreg = maxnreg / 8 * 8;
+      }
+
+      // Compute the register deficit over the partition warp groups.
+      int registerDeficit = 0;
+      for (const WarpGroupInfo &wg : warpGroups) {
+        assert(wg.numWarps % 4 == 0);
+        registerDeficit +=
+            (maxnreg - wg.maxRequestedRegs) * wg.numWarps * threadsPerWarp;
+      }
+      if (registerDeficit <= 0)
+        return;
+
+      // Determine the number of extra registers that we can distribute to the
+      // default warp group.
+      int leftover =
+          ((baseNumWarps * threadsPerWarp * maxnreg) + registerDeficit) /
+          baseNumWarps / threadsPerWarp;
+      // Round down to the nearest multiple of 8.
+      leftover = leftover / 8 * 8;
+
+      // Generate setmaxnreg in each partition according to its warp group.
+      SmallVector<int32_t> maxnregsPerPartition(1 + arr.size());
+      for (const WarpGroupInfo &wg : warpGroups) {
+        for (Region *region : wg.partitions) {
+          maxnregsPerPartition[1 + region->getRegionNumber()] =
+              wg.maxRequestedRegs;
+        }
+      }
+      // Set the register usage for the default warp group.
+      maxnregsPerPartition.front() = leftover;
+      op.setActualRegisters(maxnregsPerPartition);
+
+      // Set the initial max number of registers. This is needed for PTXAS to
+      // cooperate.
+      mod->setAttr(AttrMaxRegistersName,
+                   Builder(op.getContext()).getI32IntegerAttr(maxnreg));
    });
 
    Builder b(&getContext());
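Two details worth noting in the hunk above: register counts are snapped to multiples of 8 because the PTX `setmaxnreg` directive only accepts multiples of 8 in the range 24–256, and the redistribution only fires when the partition warps form whole warpgroups (`getTotalPartitionWarps() % 4 == 0`), since `setmaxnreg` takes effect at warpgroup granularity.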

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 14 additions & 6 deletions
@@ -125,11 +125,11 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
           mlir::IntegerType::get(llvmFuncOp.getContext(), 8);
       const auto arrayType = mlir::LLVM::LLVMArrayType::get(
           llvmFuncOp.getContext(), byteType, 128);
-      llvmFuncOp.setArgAttr(i, "llvm.byval",
+      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getByValAttrName(),
                            mlir::TypeAttr::get(arrayType));
-      llvmFuncOp.setArgAttr(i, "nvvm.grid_constant",
+      llvmFuncOp.setArgAttr(i, NVVM::NVVMDialect::getGridConstantAttrName(),
                            mlir::UnitAttr::get(llvmFuncOp.getContext()));
-      llvmFuncOp.setArgAttr(i, "llvm.align",
+      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getAlignAttrName(),
                            mlir::IntegerAttr::get(i32_type, 64));
     }
   }
@@ -155,7 +155,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
 
     if (LLVM::isKernel(funcOp)) {
       // Set an attribute to indicate this function is a kernel entry.
-      newFuncOp->setAttr("nvvm.kernel",
+      newFuncOp->setAttr(NVVM::NVVMDialect::getKernelFuncAttrName(),
                          rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
       newFuncOp.setLinkage(LLVM::Linkage::External);
     } else {
@@ -166,12 +166,20 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
           ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
       newFuncOp.setLinkage(LLVM::Linkage::Internal);
     }
-    // Set an attribute for reqntidx, it could be used in latter LLVM codegen
-    // for `nvvm.annotation` metadata.
+
+    // Determine the actual number of required warps.
     int numWarps = triton::gpu::lookupNumWarps(funcOp);
     if (auto totalNumWarps = funcOp.getParentOp()->getAttrOfType<IntegerAttr>(
            "ttg.total-num-warps"))
      numWarps = totalNumWarps.getInt();
+
+    // Set `nvvm.maxnreg` if it was specified on the module.
+    if (Attribute maxnregAttr =
+            funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName))
+      newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);
+
+    // Set an attribute for reqntidx, it could be used in latter LLVM codegen
+    // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
                        rewriter.getDenseI32ArrayAttr(32 * numWarps));
 
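For the warp-specialized example in the new test, `ttg.total-num-warps` would be 12, so the kernel is annotated with `reqntid = 384`, and the module's `ttg.maxnreg = 168` is forwarded as `nvvm.maxnreg`, giving PTXAS the static per-thread cap that the runtime `setmaxnreg` directives later adjust per warpgroup.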

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -769,7 +769,7 @@ void WarpSpecializeOp::build(OpBuilder &builder, OperationState &state,
                              ArrayRef<int32_t> partitionNumWarps,
                              unsigned partitionNumRegions) {
   build(builder, state, resultTypes, /*explicitCaptures=*/ValueRange(),
-        partitionNumWarps, /*warpGroupStartIds=*/{});
+        partitionNumWarps, {}, {}, {});
   OpBuilder::InsertionGuard guard(builder);
   Block *container = builder.createBlock(state.regions.back().get());
   builder.create<WarpSpecializePartitionsOp>(state.location,

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp

Lines changed: 9 additions & 2 deletions
@@ -249,16 +249,23 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
     }
   } while (changed);
 
-  for (auto [partition, newNumWarps, prevNumWarps, tensorRegs] :
+  SmallVector<int32_t> estRegUsage(partitionNumWarps.size());
+  for (auto [partition, newNumWarps, prevNumWarps, tensorRegs, estRegs] :
        llvm::zip(wsOp.getPartitionRegions(), partitionNumWarps,
-                 wsOp.getPartitionNumWarps(), maxTensorRegs)) {
+                 wsOp.getPartitionNumWarps(), maxTensorRegs, estRegUsage)) {
+    // "Guess" the register usage for each partition.
+    estRegs = tensorRegs ? 80 : 48;
+
+    // Layouts need to be reassigned if the number of warps changed and there
+    // are tensor computations.
     if (newNumWarps == prevNumWarps || !tensorRegs)
       continue;
     // We need to reassign layouts.
     if (failed(relayoutWarps(axisInfo, partition, prevNumWarps, newNumWarps,
                             runPipeline)))
      return failure();
  }
+  wsOp.setRequestedRegisters(estRegUsage);
  wsOp.setPartitionNumWarps(partitionNumWarps);
  return success();
}
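The 80-versus-48 split is the hard-coding the commit message refers to: a partition that materializes tensor values is assumed to need roughly 80 registers per thread, while one that only does control flow and data movement gets by on about 48. These guesses become the `requestedRegisters` array that AllocateWarpGroups consumes, e.g. `array<i32: 48, 80, 48>` in the test below.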

test/Conversion/allocate_warp_groups.mlir

Lines changed: 25 additions & 0 deletions
@@ -63,3 +63,28 @@ tt.func @two_warp_specialize() {
 }
 
 }
+
+// -----
+
+// CHECK: module attributes {ttg.maxnreg = 168 : i32
+module attributes {"ttg.num-warps" = 8 : i32} {
+
+tt.func @setmaxnreg() {
+  // CHECK: actualRegisters = array<i32: 208, 80, 80, 80>
+  ttg.warp_specialize() attributes {requestedRegisters = array<i32: 48, 80, 48>}
+  default {
+    ttg.warp_yield
+  }
+  partition0() num_warps(1) {
+    ttg.warp_return
+  }
+  partition1() num_warps(2) {
+    ttg.warp_return
+  }
+  partition2() num_warps(1) {
+    ttg.warp_return
+  } : () -> ()
+  tt.return
+}
+
+}

third_party/nvidia/backend/compiler.py

Lines changed: 4 additions & 6 deletions
@@ -238,6 +238,10 @@ def make_ttir(mod, metadata, opt):
 
     @staticmethod
     def make_ttgir(mod, metadata, opt, capability):
+        # Set maxnreg on all kernels, if it was provided.
+        if opt.maxnreg is not None:
+            mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
+
         cluster_info = nvidia.ClusterInfo()
         if opt.cluster_dims is not None:
             cluster_info.clusterDimX = opt.cluster_dims[0]
@@ -335,12 +339,6 @@ def make_llir(self, src, metadata, options, capability):
         llvm.attach_datalayout(llvm_mod, triple, proc, features)
         nvidia.set_nvvm_reflect_ftz(llvm_mod)
 
-        # Set maxnreg on all kernels, if it was provided.
-        if options.maxnreg is not None:
-            for k in llvm_mod.get_functions():
-                if not k.is_declaration() and k.is_external_linkage():
-                    k.set_nvvm_maxnreg(options.maxnreg)
-
         if options.extern_libs:
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp

Lines changed: 20 additions & 0 deletions
@@ -216,6 +216,26 @@ static LogicalResult rewriteWarpGroupBarriers(LLVM::LLVMFuncOp func,
         bar.erase();
       });
     }
+
+    if (auto actRegisters = op.getActualRegisters()) {
+      int maxnreg = func->getParentOfType<ModuleOp>()
+                        ->getAttrOfType<IntegerAttr>(AttrMaxRegistersName)
+                        .getInt();
+      auto b = OpBuilder::atBlockBegin(&op.getDefaultRegion().front());
+      b.create<NVVM::SetMaxRegisterOp>(op.getLoc(),
+                                       std::min(256, actRegisters->front()),
+                                       NVVM::SetMaxRegisterAction::increase);
+      for (auto [actRegs, region] :
+           llvm::zip(actRegisters->drop_front(), op.getPartitionRegions())) {
+        if (actRegs == maxnreg)
+          continue;
+        auto action = actRegs < maxnreg ? NVVM::SetMaxRegisterAction::decrease
+                                        : NVVM::SetMaxRegisterAction::increase;
+        b.setInsertionPointToStart(&region->front());
+        b.create<NVVM::SetMaxRegisterOp>(op.getLoc(), std::min(256, actRegs),
+                                         action);
+      }
+    }
   }
 
   return success();
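For the test case above, this emits a single `nvvm.setmaxregister` increase to 208 at the top of the default region and a decrease to 80 at the top of each partition region: the module-level maxnreg of 168 serves as the reference point that picks the direction, and the `std::min(256, ...)` clamps requests to the 256-register ceiling that `setmaxnreg` accepts.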

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 37 additions & 6 deletions
@@ -100,8 +100,8 @@ TMemMessageTraits getTMemMessageFromAtom(const TMemAccessAtom &atom,
 // Only allows half of the thread registers to be used for tensor memory access
 // to avoid register pressure. This ensures the largest tmem message width is
 // used for the workload without inducing spills.
-int getTMemMessageNarrowingFactor(int workloadThreadRegs) {
-  const int allowedRegUsage = maxRegisters / 2;
+int getTMemMessageNarrowingFactor(int workloadThreadRegs, int maxnreg) {
+  const int allowedRegUsage = maxnreg / 2;
   int narrowingFactor = 1;
   while (workloadThreadRegs > allowedRegUsage) {
     workloadThreadRegs /= 2;
@@ -338,13 +338,13 @@ void createWaitOpSt(Location loc, ConversionPatternRewriter &rewriter) {
   ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext()));
 }
 
-TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info) {
+TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info, int maxnreg) {
   auto atom = info.useStridedMessage ? TMemAccess16x32bx2 : TMemAccess32x32b;
 
   int totalRegsNeeded =
       getEffectiveRegs(info.unpackedb16, info.useStridedMessage,
                        info.numCols / info.numWarpGroups);
-  int narrowingFactor = getTMemMessageNarrowingFactor(totalRegsNeeded);
+  int narrowingFactor = getTMemMessageNarrowingFactor(totalRegsNeeded, maxnreg);
   auto narrowedMessage = getTMemMessageFromAtom(atom, narrowingFactor);
   narrowedMessage = constrainMessageFromWorkload(narrowedMessage, info,
                                                  narrowedMessage.numRegs);
@@ -355,6 +355,35 @@ TMemMessageTraits selectTMemMessage(const TMemRuntimeInfo &info, int maxnreg) {
   return std::min(narrowedMessage, maxWidthMessage);
 }
 
+// Get the maximum number of registers per thread based on the context. This is
+// by default 256, but it can be overridden by `ttg.maxnreg` set on the module.
+// Alternatively, warp groups within warp specialized regions can have a
+// different number of registers allocated.
+static int getContextualMaxNReg(Operation *op) {
+  if (auto mod = dyn_cast<ModuleOp>(op)) {
+    // Check for a maxnreg attribute.
+    if (auto attr = op->getAttrOfType<IntegerAttr>(AttrMaxRegistersName))
+      return std::max<int>(maxRegisters, attr.getInt());
+
+  } else if (auto partitions =
+                 dyn_cast<WarpSpecializePartitionsOp>(op->getParentOp())) {
+    // Check if the partition has reduced registers.
+    unsigned idx = op->getParentRegion()->getRegionNumber();
+    if (auto actRegisters = partitions.getParentOp().getActualRegisters())
+      return std::max<int>(maxRegisters, (*actRegisters)[1 + idx]);
+    return getContextualMaxNReg(partitions.getParentOp());
+
+  } else if (auto wsOp = dyn_cast<WarpSpecializeOp>(op->getParentOp())) {
+    // Check the register usage of the default warpgroup.
+    if (auto actRegisters = wsOp.getActualRegisters())
+      return std::max<int>(maxRegisters, actRegisters->front());
+  }
+
+  if (Operation *parent = op->getParentOp())
+    return getContextualMaxNReg(parent);
+  return maxRegisters;
+}
+
 static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src,
                                      Value dest, Value llSrc, Value pred,
                                      Value tmemBase,
@@ -365,7 +394,8 @@ static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src,
   auto dstType = cast<MemDescType>(dest.getType());
   auto info = getTMemRuntimeInfo(op, cast<RankedTensorType>(src.getType()),
                                  cast<MemDescType>(dest.getType()));
-  const TMemMessageTraits message = selectTMemMessage(info);
+  const TMemMessageTraits message =
+      selectTMemMessage(info, getContextualMaxNReg(op));
   int regIdx = 0;
   calculateAddressAndEmitTmemMessage(
       loc, tmemBase, info, message, rewriter,
@@ -503,7 +533,8 @@ struct TensorMemoryLoadOpConversion
 
     auto info = getTMemRuntimeInfo(op, cast<RankedTensorType>(op.getType()),
                                    cast<MemDescType>(op.getSrc().getType()));
-    const TMemMessageTraits message = selectTMemMessage(info);
+    const TMemMessageTraits message =
+        selectTMemMessage(info, getContextualMaxNReg(op));
     SmallVector<Value> resultVals;
     calculateAddressAndEmitTmemMessage(
         loc, tmemBase, info, message, rewriter,
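The net effect is that tmem message selection now keys off the contextual register budget rather than a single global constant: an op inside a warp-specialized partition consults that partition's entry in `actualRegisters`, an op in the default region consults the boosted front entry, and anything else falls back to the module's `ttg.maxnreg`, so a warpgroup running with a reduced budget can be steered toward narrower tensor-memory messages.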
