
Commit 71e0eb7

Merge commit '711caa4d78a56cc7eb539b61501b904a94bba2db'
2 parents 65f7671 + 711caa4

File tree

13 files changed: +249 −38 lines


include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -39,6 +39,7 @@ template <> struct hash<CacheKey> {
 
 namespace mlir::triton::gpu {
 
+constexpr static char AttrMaxRegistersName[] = "ttg.maxnreg";
 constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
 constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
 constexpr static char AttrTargetName[] = "ttg.target";
```

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 3 additions & 1 deletion
```diff
@@ -376,7 +376,9 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
   let arguments = (ins
     Variadic<AnyType>:$explicitCaptures,
     DenseI32ArrayAttr:$partitionNumWarps,
-    OptionalAttr<DenseI32ArrayAttr>:$warpGroupStartIds
+    OptionalAttr<DenseI32ArrayAttr>:$warpGroupStartIds,
+    OptionalAttr<DenseI32ArrayAttr>:$requestedRegisters,
+    OptionalAttr<DenseI32ArrayAttr>:$actualRegisters
   );
   let results = (outs Variadic<AnyType>:$defaultPassthrough);
```
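The two new optional attributes form a small handshake between the warp-specialization passes in this commit: `OptimizePartitionWarps` records a per-partition register estimate in `requestedRegisters`, and `AllocateWarpGroups` reads that estimate and writes the final per-warpgroup budgets into `actualRegisters` (index 0 is the default warp group). A hedged sketch of the ODS-generated accessors in use (signatures inferred from how the passes below call them, not copied from the generated header):

```cpp
// Illustrative only: mirrors how the passes in this commit use the new
// accessors; the include and function here are hypothetical scaffolding.
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

void registerHandshake(triton::gpu::WarpSpecializeOp wsOp) {
  // Producer (OptimizePartitionWarps): one estimate per partition region.
  SmallVector<int32_t> estimates = {48, 80, 48};
  wsOp.setRequestedRegisters(estimates);

  // Consumer (AllocateWarpGroups): read the estimates back; the optional is
  // empty if no estimate was ever recorded.
  if (auto requested = wsOp.getRequestedRegisters()) {
    // Final decision, with index 0 reserved for the default warp group.
    SmallVector<int32_t> budgets = {208, 80, 80, 80};
    wsOp.setActualRegisters(budgets);
  }
}
```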

lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp

Lines changed: 89 additions & 0 deletions
```diff
@@ -18,6 +18,20 @@ struct AllocateWarpGroups
   void runOnOperation() override {
     ModuleOp mod = getOperation();
 
+    int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
+
+    struct WarpGroupInfo {
+      SmallVector<Region *> partitions;
+      int maxRequestedRegs = 0;
+      unsigned numWarps = 0;
+    };
+    struct WarpGroupPartition {
+      int startId;
+      Region *partition;
+      int32_t estRegs;
+      int numWarps;
+    };
+
     // Compute the total number of warps required at any given time.
     int baseNumWarps = lookupNumWarps(mod);
     int maxExtraWarps = 0;
@@ -42,6 +56,81 @@ struct AllocateWarpGroups
         startId += size;
       }
       op.setWarpGroupStartIds(startIds);
+
+      // Require that an estimate has been set and that the partitions form
+      // whole warpgroups.
+      auto regsAttr = op.getRequestedRegisters();
+      if (!regsAttr || op.getTotalPartitionWarps() % 4 != 0)
+        return;
+
+      // Group the partitions into warpgroups.
+      SmallVector<WarpGroupPartition> orderedPartitions;
+      for (auto [startId, partition, estRegs, numWarps] :
+           llvm::zip(startIds, op.getPartitionRegions(), *regsAttr, arr))
+        orderedPartitions.push_back({startId, partition, estRegs, numWarps});
+      llvm::sort(orderedPartitions,
+                 [&](auto lhs, auto rhs) { return lhs.startId < rhs.startId; });
+
+      // Iterate over the partitions and assign them to warp groups. Determine
+      // the maximum number of requested registers per warp group.
+      SmallVector<WarpGroupInfo> warpGroups;
+      for (auto [startId, partition, estRegs, numWarps] : orderedPartitions) {
+        if (startId % 4 == 0) {
+          warpGroups.push_back(WarpGroupInfo{});
+        }
+        warpGroups.back().partitions.push_back(partition);
+        // Round up to the nearest multiple of 8.
+        int estRegsCeil8 = llvm::divideCeil(estRegs, 8) * 8;
+        warpGroups.back().maxRequestedRegs =
+            std::max<int>(warpGroups.back().maxRequestedRegs, estRegsCeil8);
+        warpGroups.back().numWarps += numWarps;
+      }
+
+      // Determine the maximum number of registers per thread. This may have
+      // been set by the user.
+      int maxnreg;
+      if (auto maxnregAttr =
+              op->getAttrOfType<IntegerAttr>(AttrMaxRegistersName)) {
+        maxnreg = maxnregAttr.getInt();
+      } else {
+        maxnreg = (1 << 16) / (baseNumWarps + op.getTotalPartitionWarps()) /
+                  threadsPerWarp;
+        maxnreg = maxnreg / 8 * 8;
+      }
+
+      // Compute the register deficit over the partition warp groups.
+      int registerDeficit = 0;
+      for (const WarpGroupInfo &wg : warpGroups) {
+        assert(wg.numWarps % 4 == 0);
+        registerDeficit +=
+            (maxnreg - wg.maxRequestedRegs) * wg.numWarps * threadsPerWarp;
+      }
+      if (registerDeficit <= 0)
+        return;
+
+      // Determine the number of extra registers that we can distribute to the
+      // default warp group.
+      int leftover =
+          ((baseNumWarps * threadsPerWarp * maxnreg) + registerDeficit) /
+          baseNumWarps / threadsPerWarp;
+      // Round down to the nearest multiple of 8.
+      leftover = leftover / 8 * 8;
+
+      // Generate setmaxnreg in each partition according to its warp group.
+      SmallVector<int32_t> maxnregsPerPartition(1 + arr.size());
+      for (const WarpGroupInfo &wg : warpGroups) {
+        for (Region *region : wg.partitions) {
+          maxnregsPerPartition[1 + region->getRegionNumber()] =
+              wg.maxRequestedRegs;
+        }
+      }
+      // Set the register usage for the default warp group.
+      maxnregsPerPartition.front() = leftover;
+      op.setActualRegisters(maxnregsPerPartition);
+
+      // Set the initial max number of registers. This is needed for PTXAS to
+      // cooperate.
+      mod->setAttr(AttrMaxRegistersName,
+                   Builder(op.getContext()).getI32IntegerAttr(maxnreg));
     });
 
     Builder b(&getContext());
```
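To make the arithmetic above concrete, the following minimal, self-contained C++ sketch replays the register accounting on the values from the `setmaxnreg` test added later in this commit (8 default warps, 32 threads per warp, partitions of 1 + 2 + 1 warps requesting 48/80/48 registers). Variable names are illustrative, not part of the pass:

```cpp
#include <cassert>

int main() {
  const int threadsPerWarp = 32;
  const int baseNumWarps = 8;        // "ttg.num-warps" on the module
  const int totalPartitionWarps = 4; // partitions of 1 + 2 + 1 warps

  // No user-provided ttg.maxnreg: split the 64K-register file evenly across
  // all warps, then round down to a multiple of 8.
  int maxnreg =
      (1 << 16) / (baseNumWarps + totalPartitionWarps) / threadsPerWarp;
  maxnreg = maxnreg / 8 * 8;
  assert(maxnreg == 168); // CHECK: ttg.maxnreg = 168

  // The four partition warps form one warpgroup; its budget is the largest
  // request rounded up to a multiple of 8: max(48, 80, 48) -> 80.
  const int maxRequestedRegs = 80;
  const int wgNumWarps = 4;
  const int registerDeficit =
      (maxnreg - maxRequestedRegs) * wgNumWarps * threadsPerWarp;
  assert(registerDeficit == 11264);

  // The deficit is redistributed to the default warp group, rounded down to
  // a multiple of 8.
  int leftover = ((baseNumWarps * threadsPerWarp * maxnreg) + registerDeficit) /
                 baseNumWarps / threadsPerWarp;
  leftover = leftover / 8 * 8;
  assert(leftover == 208); // CHECK: actualRegisters = array<i32: 208, 80, 80, 80>
  return 0;
}
```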

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 14 additions & 6 deletions
```diff
@@ -125,11 +125,11 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
           mlir::IntegerType::get(llvmFuncOp.getContext(), 8);
       const auto arrayType = mlir::LLVM::LLVMArrayType::get(
           llvmFuncOp.getContext(), byteType, 128);
-      llvmFuncOp.setArgAttr(i, "llvm.byval",
+      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getByValAttrName(),
                             mlir::TypeAttr::get(arrayType));
-      llvmFuncOp.setArgAttr(i, "nvvm.grid_constant",
+      llvmFuncOp.setArgAttr(i, NVVM::NVVMDialect::getGridConstantAttrName(),
                             mlir::UnitAttr::get(llvmFuncOp.getContext()));
-      llvmFuncOp.setArgAttr(i, "llvm.align",
+      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getAlignAttrName(),
                             mlir::IntegerAttr::get(i32_type, 64));
     }
   }
@@ -155,7 +155,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
 
     if (LLVM::isKernel(funcOp)) {
       // Set an attribute to indicate this function is a kernel entry.
-      newFuncOp->setAttr("nvvm.kernel",
+      newFuncOp->setAttr(NVVM::NVVMDialect::getKernelFuncAttrName(),
                          rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
       newFuncOp.setLinkage(LLVM::Linkage::External);
     } else {
@@ -166,12 +166,20 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
           ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
       newFuncOp.setLinkage(LLVM::Linkage::Internal);
     }
-    // Set an attribute for reqntidx, it could be used in latter LLVM codegen
-    // for `nvvm.annotation` metadata.
+
+    // Determine the actual number of required warps.
     int numWarps = triton::gpu::lookupNumWarps(funcOp);
     if (auto totalNumWarps = funcOp.getParentOp()->getAttrOfType<IntegerAttr>(
             "ttg.total-num-warps"))
       numWarps = totalNumWarps.getInt();
+
+    // Set `nvvm.maxnreg` if it was specified on the module.
+    if (Attribute maxnregAttr =
+            funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName))
+      newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);
+
+    // Set an attribute for reqntid; it can be used in later LLVM codegen
+    // for `nvvm.annotation` metadata.
     newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
                        rewriter.getDenseI32ArrayAttr(32 * numWarps));
```

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -771,7 +771,7 @@ void WarpSpecializeOp::build(OpBuilder &builder, OperationState &state,
                              ArrayRef<int32_t> partitionNumWarps,
                              unsigned partitionNumRegions) {
   build(builder, state, resultTypes, /*explicitCaptures=*/ValueRange(),
-        partitionNumWarps, /*warpGroupStartIds=*/{});
+        partitionNumWarps, {}, {}, {});
   OpBuilder::InsertionGuard guard(builder);
   Block *container = builder.createBlock(state.regions.back().get());
   builder.create<WarpSpecializePartitionsOp>(state.location,
```

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 14 additions & 2 deletions
```diff
@@ -381,17 +381,29 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
            donePt.getPoint()->isBeforeInBlock(&*b.getInsertionPoint()));
     donePt = b.saveInsertionPoint();
 
-    // Acquire and get the accumulator result.
-    b.setInsertionPoint(domOp);
     Partition *userPartition = schedule.addPartition(numStages + numMmaStages);
+    // Acquire and get the accumulator result. Normally, we want to acquire
+    // the accumulator for as small a critical section as possible to unblock
+    // dependents, but if the most dominating user is inside a conditional,
+    // acquire the accumulator for the whole branch. This will improve
+    // instruction scheduling and interleaving of the TMEM load.
+    bool userInConditional = isa<scf::IfOp>(domOp->getParentOp());
+    b.setInsertionPoint(domOp);
+    if (userInConditional)
+      b.setInsertionPointToStart(domOp->getBlock());
     createInPartition<ttng::WaitBarrierOp>(b, *userPartition, curAccReadyBar,
                                            accPhase);
+
+    b.setInsertionPoint(domOp);
     Value acc = createInPartition<ttng::TMEMLoadOp>(
         b, *userPartition, info.accLoad.getType(), curAccBuf);
     for (Operation *user : accUses)
       user->replaceUsesOfWith(info.accLoad, acc);
+
     // Signal the accumulator buffer is ready for the next iteration. Because
     // the mbarriers got shifted over by 1, we have to signal the next mbarrier.
+    if (userInConditional)
+      b.setInsertionPoint(domOp->getBlock()->getTerminator());
     Value nextIndex =
         b.create<arith::AddIOp>(accIndex, intCst(numMmaStages - 1));
     nextIndex = b.create<arith::RemUIOp>(nextIndex, intCst(numMmaStages));
```

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp

Lines changed: 29 additions & 6 deletions
```diff
@@ -5,11 +5,13 @@
 #include "triton/Analysis/AxisInfo.h"
 #include "triton/Conversion/TritonToTritonGPU/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "llvm/ADT/ScopeExit.h"
 
 using namespace mlir;
 using namespace triton;
 using namespace triton::gpu;
+namespace ttng = triton::nvidia_gpu;
 
 //===----------------------------------------------------------------------===//
 // relayoutWarps
@@ -182,14 +184,28 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
   // If the compiler could control that, then we could allow non-uniform
   // register distributions, mostly beneficial for single-warp warpgroups that
   // just do some arithmetic.
-  constexpr unsigned nTotalRegs = 65536; // for Blackwell SMs
+  constexpr unsigned nTotalRegs = 1 << 16; // for Blackwell SMs
   const unsigned threadsPerWarp =
       TritonGPUDialect::getThreadsPerWarp(axisInfo.getModuleOp());
   const unsigned defaultNumWarps = lookupNumWarps(wsOp);
 
   SmallVector<int32_t> partitionNumWarps =
       llvm::to_vector(wsOp.getPartitionNumWarps());
 
+  // Some instructions are throughput-critical despite low register usage.
+  // Make sure there are enough warps for these ops to execute quickly.
+  SmallVector<int32_t> minWarpsForPartition(partitionNumWarps.size(), 1);
+  for (auto [minWarps, region] :
+       llvm::zip(minWarpsForPartition, wsOp.getPartitionRegions())) {
+    region->walk([minWarps = &minWarps](Operation *op) {
+      if (!isa<scf::ForOp>(op->getParentOp()))
+        return;
+      if (isa<ttng::AsyncTMAGatherOp, ttng::AsyncTMAScatterOp,
+              ttng::AsyncTMACopyGlobalToLocalOp>(op))
+        *minWarps = 2;
+    });
+  }
+
   bool changed;
   do {
     changed = false;
@@ -215,9 +231,9 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
     int32_t curTotalNumWarps = std::accumulate(
         partitionNumWarps.begin(), partitionNumWarps.end(), defaultNumWarps);
 
-    for (auto [numWarps, tensorRegs] :
-         llvm::zip(partitionNumWarps, maxTensorRegs)) {
-      if (numWarps == 1)
+    for (auto [minWarps, numWarps, tensorRegs] :
+         llvm::zip(minWarpsForPartition, partitionNumWarps, maxTensorRegs)) {
+      if (numWarps <= minWarps)
         continue;
       // Check if reducing the number of warps will still fit the tensor. If it
       // didn't fit to begin with, it won't fit after shrinking.
@@ -233,16 +249,23 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
     }
   } while (changed);
 
-  for (auto [partition, newNumWarps, prevNumWarps, tensorRegs] :
+  SmallVector<int32_t> estRegUsage(partitionNumWarps.size());
+  for (auto [partition, newNumWarps, prevNumWarps, tensorRegs, estRegs] :
       llvm::zip(wsOp.getPartitionRegions(), partitionNumWarps,
-                 wsOp.getPartitionNumWarps(), maxTensorRegs)) {
+                 wsOp.getPartitionNumWarps(), maxTensorRegs, estRegUsage)) {
+    // "Guess" the register usage for each partition.
+    estRegs = tensorRegs ? 80 : 48;
+
+    // Layouts need to be reassigned if the number of warps changed and there
+    // are tensor computations.
     if (newNumWarps == prevNumWarps || !tensorRegs)
       continue;
     // We need to reassign layouts.
     if (failed(relayoutWarps(axisInfo, partition, prevNumWarps, newNumWarps,
                              runPipeline)))
       return failure();
   }
+  wsOp.setRequestedRegisters(estRegUsage);
   wsOp.setPartitionNumWarps(partitionNumWarps);
   return success();
 }
```
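Two heuristics introduced here drive the test updates below. A minimal standalone sketch of both (helper names are illustrative, not part of the pass):

```cpp
#include <cstdint>

// Partitions that issue TMA gather/scatter or async global-to-local copies
// inside a loop are pinned to at least 2 warps, so these low-register,
// throughput-critical ops are not starved. This is why partition1 in the
// tests below now expects num_warps(2) instead of num_warps(1).
int32_t minWarpsForPartition(bool hasLoopTMAOps) {
  return hasLoopTMAOps ? 2 : 1;
}

// Per-partition register estimate recorded as `requestedRegisters`: 80 if the
// partition performs tensor computations, 48 otherwise. AllocateWarpGroups
// later rounds these up to multiples of 8 and balances them per warpgroup.
int32_t estimateRegisters(bool hasTensorRegs) {
  return hasTensorRegs ? 80 : 48;
}
```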

test/Conversion/allocate_warp_groups.mlir

Lines changed: 25 additions & 0 deletions
```diff
@@ -63,3 +63,28 @@ tt.func @two_warp_specialize() {
 }
 
 }
+
+// -----
+
+// CHECK: module attributes {ttg.maxnreg = 168 : i32
+module attributes {"ttg.num-warps" = 8 : i32} {
+
+tt.func @setmaxnreg() {
+  // CHECK: actualRegisters = array<i32: 208, 80, 80, 80>
+  ttg.warp_specialize() attributes {requestedRegisters = array<i32: 48, 80, 48>}
+  default {
+    ttg.warp_yield
+  }
+  partition0() num_warps(1) {
+    ttg.warp_return
+  }
+  partition1() num_warps(2) {
+    ttg.warp_return
+  }
+  partition2() num_warps(1) {
+    ttg.warp_return
+  } : () -> ()
+  tt.return
+}
+
+}
```

test/TritonGPU/automatic-warp-specialization.mlir

Lines changed: 2 additions & 2 deletions
```diff
@@ -32,7 +32,7 @@ tt.func @matmul_change_desc_in_prologue(
 // BASE-NOT: tt.make_tensor_descriptor
 // PIPELINE-NOT: tt.experimental_tensormap_create
 // CHECK-LABEL: partition1
-// CHECK-SAME: num_warps(1)
+// CHECK-SAME: num_warps(2)
 // BASE-COUNT-2: tt.make_tensor_descriptor
 // PIPELINE-COUNT-2: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 512 : i32}
 // PIPELINE-COUNT-2: tt.experimental_tensormap_create
@@ -87,7 +87,7 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use(
 // CHECK-LABEL: partition0
 // CHECK-SAME: num_warps(1)
 // CHECK-LABEL: partition1
-// CHECK-SAME: num_warps(1)
+// CHECK-SAME: num_warps(2)
 // CHECK: [[INDICES:%.*]] = tt.splat %{{.*}} : i32 -> tensor<128xi32,
 // CHECK: ttng.async_tma_gather %{{.*}}[[[INDICES]],
 // CHECK-LABEL: partition2
```
