Skip to content

Commit ca1ce1b

Browse files
authored
[Backend] Improve how dynamic register reallocation is implemented (triton-lang#6694)
This PR generally improves how the compiler handles dynamic register reallocation for warp specialization.

* Codegen generates `setmaxnreg` both at the entry and exit of the partition regions, even when the number of registers does not change. This somehow makes `ptxas` behave much better, allowing the registers allocated to the load and MMA partitions to drop to `24` as they should be. This should improve register pressure across the board for warp specialized kernels.
* The maximum number of warpgroups is computed across the whole program and each `ttg.warp_specialize` is padded to it. This ensures all warps are always present to surrender registers. This primarily improves the layering between partitioning in the middle end (no longer need the "extra empty warp" hack).
* Handle TMEM ops when relayout'ing the IR (they require a minimum of 4 warps). Thankfully, the TMEM compatible distributed layout can always be inferred for these ops.
1 parent 7ad7cee commit ca1ce1b

File tree

13 files changed

+452
-109
lines changed

13 files changed

+452
-109
lines changed

lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp

Lines changed: 76 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,89 @@ using namespace mlir;
1111
using namespace mlir::triton;
1212
using namespace mlir::triton::gpu;
1313

14+
// Given a `ttg.warp_specialize` with a certain number of existing warps, pad it
15+
// with extra warps until it has the same number of full warp groups as the
16+
// largest partitioning. This ensures that all threads can be present to
17+
// surrender registers.
18+
static void padToMaxWarpGroups(WarpSpecializeOp op, int numExtraWarpGroups) {
19+
int numExtraWarps = op.getTotalPartitionWarps();
20+
int warpsToAdd = numExtraWarpGroups * 4 - numExtraWarps;
21+
assert(warpsToAdd >= 0);
22+
23+
// Fill it with powers of 2.
24+
SmallVector<int> paddingPartitionSizes;
25+
while (warpsToAdd > 0) {
26+
int paddingSize = llvm::NextPowerOf2(warpsToAdd) / 2;
27+
paddingPartitionSizes.push_back(paddingSize);
28+
warpsToAdd -= paddingSize;
29+
}
30+
31+
auto partitions = cast<WarpSpecializePartitionsOp>(
32+
op.getPartitionOpHolder().front().front());
33+
OperationState state(partitions.getLoc(), partitions.getOperationName());
34+
for (Region *region : partitions.getRegions())
35+
state.addRegion()->takeBody(*region);
36+
37+
SmallVector<int32_t> partitionNumWarps(op.getPartitionNumWarps());
38+
for (int paddingSize : paddingPartitionSizes) {
39+
partitionNumWarps.push_back(paddingSize);
40+
41+
Block &body = state.addRegion()->emplaceBlock();
42+
for (Value capture : op.getExplicitCaptures())
43+
body.addArgument(capture.getType(), capture.getLoc());
44+
OpBuilder b(op.getContext());
45+
b.setInsertionPointToStart(&body);
46+
b.create<WarpReturnOp>(op.getLoc());
47+
}
48+
op.setPartitionNumWarps(partitionNumWarps);
49+
50+
// Set the requested registers to low for the padded partitions that do
51+
// nothing.
52+
if (auto reqRegs = op.getRequestedRegisters()) {
53+
SmallVector<int32_t> newReqRegs(*reqRegs);
54+
newReqRegs.append(paddingPartitionSizes.size(), 16);
55+
op.setRequestedRegisters(newReqRegs);
56+
}
57+
58+
OpBuilder b(partitions);
59+
b.create(state);
60+
partitions.erase();
61+
}
62+
1463
namespace {
1564
struct AllocateWarpGroups
1665
: public mlir::triton::gpu::impl::TritonGPUAllocateWarpGroupsBase<
1766
AllocateWarpGroups> {
1867
void runOnOperation() override {
1968
ModuleOp mod = getOperation();
2069

70+
// First determine the maximum number of extra warps.
71+
int maxExtraWarps = 0;
72+
mod.walk([&](WarpSpecializeOp op) {
73+
maxExtraWarps = std::max<int>(maxExtraWarps, op.getTotalPartitionWarps());
74+
});
75+
76+
// Round this up to the nearest warpgroup (multiple of 4) and then pad each
77+
// `ttg.warp_specialize` to the nearest warpgroup.
78+
int numExtraWarpGroups = llvm::divideCeil(maxExtraWarps, 4);
79+
mod.walk([&](WarpSpecializeOp op) {
80+
padToMaxWarpGroups(op, numExtraWarpGroups);
81+
});
82+
83+
// Determine the maximum number of registers per thread. This may have
84+
// been set by the user.
2185
int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
86+
int baseNumWarps = lookupNumWarps(mod);
87+
int maxnreg;
88+
if (auto maxnregAttr =
89+
mod->getAttrOfType<IntegerAttr>(AttrMaxRegistersName)) {
90+
maxnreg = maxnregAttr.getInt();
91+
} else {
92+
// Assume the user wants to use all 64K registers.
93+
maxnreg = (64 * 1024) / (baseNumWarps + numExtraWarpGroups * 4) /
94+
threadsPerWarp;
95+
maxnreg = maxnreg / 8 * 8;
96+
}
2297

2398
struct WarpGroupInfo {
2499
SmallVector<Region *> partitions;
@@ -33,12 +108,8 @@ struct AllocateWarpGroups
33108
};
34109

35110
// Compute the total number of warps required at any given time.
36-
int baseNumWarps = lookupNumWarps(mod);
37-
int maxExtraWarps = 0;
38111
mod.walk([&](WarpSpecializeOp op) {
39112
ArrayRef<int32_t> arr = op.getPartitionNumWarps();
40-
int req = op.getTotalPartitionWarps();
41-
maxExtraWarps = std::max(maxExtraWarps, req);
42113

43114
// Allocate the start IDs such that the largest warpgroups have lower
44115
// starting warp IDs.
@@ -85,18 +156,6 @@ struct AllocateWarpGroups
85156
warpGroups.back().numWarps += numWarps;
86157
}
87158

88-
// Determine the maximum number of registers per thread. This may have
89-
// been set by the user.
90-
int maxnreg;
91-
if (auto maxnregAttr =
92-
op->getAttrOfType<IntegerAttr>(AttrMaxRegistersName)) {
93-
maxnreg = maxnregAttr.getInt();
94-
} else {
95-
maxnreg = (1 << 16) / (baseNumWarps + op.getTotalPartitionWarps()) /
96-
threadsPerWarp;
97-
maxnreg = maxnreg / 8 * 8;
98-
}
99-
100159
// Compute the register deficit over the partition warp groups.
101160
int registerDeficit = 0;
102161
for (const WarpGroupInfo &wg : warpGroups) {
@@ -135,7 +194,7 @@ struct AllocateWarpGroups
135194

136195
Builder b(&getContext());
137196
mod->setAttr("ttg.total-num-warps",
138-
b.getI32IntegerAttr(baseNumWarps + maxExtraWarps));
197+
b.getI32IntegerAttr(baseNumWarps + numExtraWarpGroups * 4));
139198
}
140199
};
141200
} // namespace

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 105 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace {
2525
using namespace mlir;
2626
using namespace mlir::triton;
2727
using namespace mlir::triton::gpu;
28+
namespace ttng = triton::nvidia_gpu;
2829

2930
// pass named attrs (e.g., tt.contiguity) from Triton to Triton
3031
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
@@ -466,6 +467,72 @@ struct GatherScatterOpPattern : public OpConversionPattern<OpT> {
466467
}
467468
};
468469

470+
// Given a tensor and its representation in tensor memory, determine its
471+
// distributed layout.
472+
static RankedTensorType getTMEMTensorLayout(const TypeConverter *tc,
473+
RankedTensorType type,
474+
MemDescType memdesc,
475+
unsigned numWarps) {
476+
Attribute encoding;
477+
type = cast<RankedTensorType>(tc->convertType(type));
478+
if (isa<ttng::TensorMemoryScalesEncodingAttr>(memdesc.getEncoding())) {
479+
encoding = LinearEncodingAttr::get(
480+
type.getContext(), getScaleTMEMStoreLinearLayout(type, numWarps));
481+
} else {
482+
auto tmemEnc = cast<ttng::TensorMemoryEncodingAttr>(memdesc.getEncoding());
483+
encoding = ttng::getTmemCompatibleLayout(
484+
tmemEnc.getBlockM(), tmemEnc.getBlockN(), type, numWarps);
485+
}
486+
return RankedTensorType::get(type.getShape(), type.getElementType(),
487+
encoding);
488+
}
489+
490+
struct TMEMLoadOpPattern : public OpConversionPattern<ttng::TMEMLoadOp> {
491+
using OpConversionPattern::OpConversionPattern;
492+
493+
LogicalResult
494+
matchAndRewrite(ttng::TMEMLoadOp op, OpAdaptor adaptor,
495+
ConversionPatternRewriter &rewriter) const override {
496+
RankedTensorType type = getTMEMTensorLayout(
497+
typeConverter, op.getType(), op.getSrc().getType(), lookupNumWarps(op));
498+
rewriter.modifyOpInPlace(op, [&] { op.getResult().setType(type); });
499+
return success();
500+
}
501+
};
502+
503+
struct TMEMStoreOpPattern : public OpConversionPattern<ttng::TMEMStoreOp> {
504+
using OpConversionPattern::OpConversionPattern;
505+
506+
LogicalResult
507+
matchAndRewrite(ttng::TMEMStoreOp op, OpAdaptor adaptor,
508+
ConversionPatternRewriter &rewriter) const override {
509+
RankedTensorType type =
510+
getTMEMTensorLayout(typeConverter, op.getSrc().getType(),
511+
op.getDst().getType(), lookupNumWarps(op));
512+
Value src =
513+
rewriter.create<ConvertLayoutOp>(op.getLoc(), type, adaptor.getSrc());
514+
rewriter.modifyOpInPlace(op, [&] { op.getSrcMutable().assign(src); });
515+
return success();
516+
}
517+
};
518+
519+
struct TMEMAllocOpPattern : public OpConversionPattern<ttng::TMEMAllocOp> {
520+
using OpConversionPattern::OpConversionPattern;
521+
522+
LogicalResult
523+
matchAndRewrite(ttng::TMEMAllocOp op, OpAdaptor adaptor,
524+
ConversionPatternRewriter &rewriter) const override {
525+
if (!op.getSrc())
526+
return success();
527+
RankedTensorType type = getTMEMTensorLayout(
528+
typeConverter, op.getSrc().getType(), op.getType(), lookupNumWarps(op));
529+
Value src =
530+
rewriter.create<ConvertLayoutOp>(op.getLoc(), type, adaptor.getSrc());
531+
rewriter.modifyOpInPlace(op, [&] { op.getSrcMutable().assign(src); });
532+
return success();
533+
}
534+
};
535+
469536
struct TritonTransPattern : public OpConversionPattern<TransOp> {
470537
using OpConversionPattern::OpConversionPattern;
471538

@@ -592,40 +659,61 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
592659
MLIRContext *context = patterns.getContext();
593660
patterns.insert< // TODO: view should have custom pattern that views the
594661
// layout
662+
// clang-format off
595663
GenericOpPattern<triton::AdvanceOp>,
596664
GenericOpPattern<triton::MakeTensorPtrOp>,
597-
GenericOpPattern<triton::ReshapeOp>, GenericOpPattern<triton::BitcastOp>,
598-
GenericOpPattern<triton::FpToFpOp>, GenericOpPattern<triton::IntToPtrOp>,
599-
GenericOpPattern<triton::PtrToIntOp>, GenericOpPattern<triton::SplatOp>,
600-
TritonBroadcastPattern, GenericOpPattern<triton::AddPtrOp>,
601-
TritonCatPattern, TritonJoinOpPattern, TritonSplitOpPattern,
665+
GenericOpPattern<triton::ReshapeOp>,
666+
GenericOpPattern<triton::BitcastOp>,
667+
GenericOpPattern<triton::FpToFpOp>,
668+
GenericOpPattern<triton::IntToPtrOp>,
669+
GenericOpPattern<triton::PtrToIntOp>,
670+
GenericOpPattern<triton::SplatOp>,
671+
GenericOpPattern<triton::AddPtrOp>,
672+
TritonBroadcastPattern,
673+
TritonCatPattern,
674+
TritonJoinOpPattern,
675+
TritonSplitOpPattern,
602676
GenericOpPattern<triton::ClampFOp>,
603677
GenericOpPattern<triton::PreciseSqrtOp>,
604678
GenericOpPattern<triton::PreciseDivFOp>,
605679
GenericOpPattern<triton::MulhiUIOp>,
606-
GenericOpPattern<triton::ElementwiseInlineAsmOp>, TritonReducePattern,
607-
GenericOpPattern<triton::ReduceReturnOp>, TritonScanPattern,
680+
GenericOpPattern<triton::ElementwiseInlineAsmOp>,
681+
TritonReducePattern,
682+
GenericOpPattern<triton::ReduceReturnOp>,
683+
TritonScanPattern,
608684
GenericOpPattern<triton::ScanReturnOp>,
609-
GenericOpPattern<triton::MakeRangeOp>, TritonExpandDimsPattern,
610-
TritonTransPattern, TritonDotPattern,
685+
GenericOpPattern<triton::MakeRangeOp>,
686+
TritonExpandDimsPattern,
687+
TritonTransPattern,
688+
TritonDotPattern,
611689
GatherScatterOpPattern<DescriptorGatherOp>,
612690
GatherScatterOpPattern<DescriptorScatterOp>,
613-
GatherScatterOpPattern<triton::nvidia_gpu::AsyncTMAGatherOp>,
614-
GatherScatterOpPattern<triton::nvidia_gpu::AsyncTMAScatterOp>,
615-
GenericOpPattern<triton::LoadOp>, GenericOpPattern<triton::StoreOp>,
616-
GenericOpPattern<triton::HistogramOp>, GenericOpPattern<triton::GatherOp>,
691+
GatherScatterOpPattern<ttng::AsyncTMAGatherOp>,
692+
GatherScatterOpPattern<ttng::AsyncTMAScatterOp>,
693+
TMEMLoadOpPattern,
694+
TMEMStoreOpPattern,
695+
TMEMAllocOpPattern,
696+
GenericOpPattern<triton::LoadOp>,
697+
GenericOpPattern<triton::StoreOp>,
698+
GenericOpPattern<triton::HistogramOp>,
699+
GenericOpPattern<triton::GatherOp>,
617700
GenericOpPattern<triton::ExternElementwiseOp>,
618-
GenericOpPattern<triton::PrintOp>, GenericOpPattern<triton::AssertOp>,
701+
GenericOpPattern<triton::PrintOp>,
702+
GenericOpPattern<triton::AssertOp>,
619703
GenericOpPattern<triton::AtomicCASOp>,
620-
GenericOpPattern<triton::AtomicRMWOp>, GenericOpPattern<ReturnOp>,
704+
GenericOpPattern<triton::AtomicRMWOp>,
621705
GenericOpPattern<triton::DescriptorLoadOp>,
622706
GenericOpPattern<triton::DescriptorStoreOp>,
623707
GenericOpPattern<triton::DescriptorReduceOp>,
624708
GenericOpPattern<triton::ExperimentalTensormapCreateOp>,
625709
GenericOpPattern<triton::ExperimentalTensormapFenceproxyAcquireOp>,
626710
// this assumes the right layout will be set later for dot scaled.
627-
GenericOpPattern<triton::DotScaledOp>, GenericOpPattern<triton::CallOp>,
628-
TritonFuncOpPattern>(typeConverter, context);
711+
GenericOpPattern<triton::DotScaledOp>,
712+
GenericOpPattern<triton::CallOp>,
713+
GenericOpPattern<ReturnOp>,
714+
TritonFuncOpPattern
715+
// clang-format on
716+
>(typeConverter, context);
629717
}
630718
// Proton patterns
631719
// NOTE: Because Proton's inputs are scalars and not tensors this conversion

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,8 @@ static WarpSchedule getInitialSchedule(const PartitionScheme &scheme) {
213213
userPartition->insert(userOp);
214214
// Place the epilogue partition in the default warpgroup. The MMA and load
215215
// partitions shouldn't have tensor computations in them, which means they
216-
// will get assigned just 1 warp each. Add an extra partition to pad the
217-
// number of warps to the nearest warpgroup.
218-
schedule.addPartition(0);
219-
schedule.reorderPartitions({2, 1, 0, 3});
216+
// will get assigned just 1 warp each.
217+
schedule.reorderPartitions({2, 1, 0});
220218
}
221219

222220
schedule.updatePartitions();

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ static LogicalResult relayoutWarps(ModuleAxisInfoAnalysis &axisInfo,
127127
pm.addPass(createTritonGPUCoalesce());
128128
pm.addPass(createTritonGPURemoveLayoutConversions());
129129
pm.addPass(createTritonGPUOptimizeThreadLocality());
130+
pm.addPass(createTritonGPUAccelerateMatmul());
130131
pm.addPass(createTritonGPURemoveLayoutConversions());
131132
if (failed(runPipeline(pm, *container)))
132133
return failure();
@@ -192,17 +193,19 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
192193
SmallVector<int32_t> partitionNumWarps =
193194
llvm::to_vector(wsOp.getPartitionNumWarps());
194195

195-
// Some instructions have critical throughput if have low register usage. Make
196-
// sure there are enough warps for these ops to execute quickly.
196+
// Determine if a partition has a lower limit on the number of warps.
197197
SmallVector<int32_t> minWarpsForPartition(partitionNumWarps.size(), 1);
198198
for (auto [minWarps, region] :
199199
llvm::zip(minWarpsForPartition, wsOp.getPartitionRegions())) {
200200
region->walk([minWarps = &minWarps](Operation *op) {
201-
if (!isa<scf::ForOp>(op->getParentOp()))
202-
return;
201+
// Some instructions have critical throughput if they have low register usage.
202+
// Make sure there are enough warps for these ops to execute quickly.
203203
if (isa<ttng::AsyncTMAGatherOp, ttng::AsyncTMAScatterOp,
204204
ttng::AsyncTMACopyGlobalToLocalOp>(op))
205205
*minWarps = 2;
206+
// TMEM ops require at least 4 warps to be able to read all lanes.
207+
else if (isa<ttng::TMEMLoadOp, ttng::TMEMStoreOp, ttng::TMEMAllocOp>(op))
208+
*minWarps = 4;
206209
});
207210
}
208211

@@ -254,7 +257,7 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
254257
llvm::zip(wsOp.getPartitionRegions(), partitionNumWarps,
255258
wsOp.getPartitionNumWarps(), maxTensorRegs, estRegUsage)) {
256259
// "Guess" the register usage for each partition.
257-
estRegs = tensorRegs ? 80 : 48;
260+
estRegs = tensorRegs ? 72 : 24;
258261

259262
// Layouts need to be reassigned if the number of warps changed and there
260263
// are tensor computations.

test/Conversion/allocate_warp_groups.mlir

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ module attributes {"ttg.num-warps" = 4 : i32} {
66

77
// -----
88

9-
// CHECK: module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 17 : i32}
9+
// CHECK: module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 20 : i32}
1010
module attributes {"ttg.num-warps" = 4 : i32} {
1111

1212
tt.func @kernel() {
13-
// CHECK: ttg.warp_specialize() attributes {warpGroupStartIds = array<i32: 16, 4, 12>}
13+
// CHECK: ttg.warp_specialize() attributes {warpGroupStartIds = array<i32: 18, 4, 12, 16, 19>}
1414
ttg.warp_specialize()
1515
default {
1616
ttg.warp_yield
@@ -24,18 +24,20 @@ tt.func @kernel() {
2424
partition2() num_warps(4) {
2525
ttg.warp_return
2626
} : () -> ()
27+
// CHECK: partition3() num_warps(2)
28+
// CHECK: partition4() num_warps(1)
2729
tt.return
2830
}
2931

3032
}
3133

3234
// -----
3335

34-
// CHECK: module attributes {"ttg.num-warps" = 2 : i32, "ttg.total-num-warps" = 11 : i32}
35-
module attributes {"ttg.num-warps" = 2 : i32} {
36+
// CHECK: module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 16 : i32}
37+
module attributes {"ttg.num-warps" = 4 : i32} {
3638

3739
tt.func @two_warp_specialize() {
38-
// CHECK: ttg.warp_specialize() attributes {warpGroupStartIds = array<i32: 2, 4>}
40+
// CHECK: ttg.warp_specialize() attributes {warpGroupStartIds = array<i32: 12, 14, 4, 15>}
3941
ttg.warp_specialize()
4042
default {
4143
ttg.warp_yield
@@ -46,8 +48,10 @@ tt.func @two_warp_specialize() {
4648
partition1() num_warps(1) {
4749
ttg.warp_return
4850
} : () -> ()
51+
// CHECK: partition2() num_warps(8)
52+
// CHECK: partition3() num_warps(1)
4953

50-
// CHECK: ttg.warp_specialize() attributes {warpGroupStartIds = array<i32: 10, 2>}
54+
// CHECK: ttg.warp_specialize() attributes {warpGroupStartIds = array<i32: 14, 4, 12, 15>}
5155
ttg.warp_specialize()
5256
default {
5357
ttg.warp_yield

0 commit comments

Comments
 (0)