Commit 84a7073

[NVWS] Support tmem_alloc(desc_load()) pattern in aref insertion (#7734)
A follow-up to triton-lang/triton#7581, handling one remaining case supported by the TMA code in `LoadMMASpecialization`. Consider [these lines of code](https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp#L71-L73) in `LoadMMASpecialization` and [the corresponding lit test](https://github.com/triton-lang/triton/blob/main/test/TritonGPU/load-mma-specialization.mlir#L763-L816).

Main WS supports the load pattern `tmem_alloc(desc_load())`; I assume this pattern is used only for loading scales. Importantly, main WS lowers `tmem_alloc(desc_load())` in exactly the same way as `local_alloc(desc_load())`, **meaning the tmem scale operand is replaced by smem scales** after WS. This happens to work because `tc_gen5_mma_scaled` accepts both tmem and smem scale operands. However, there is a contract for scales in SMEM: their layout must be compatible with `tcgen05.cp`, because `MMALowering` generates `tcgen05.cp` on scales that live in SMEM. The use of `tcgen05.cp` in turn ensures that pipelining scaled MMAv5 is safe without double-buffering scales in TMEM, so in Triton, having scales in SMEM also implies that MMA pipelining is applicable. See triton-lang/triton#6019 for more details on SMEM scales.

Scales whose layout is compatible with `tcgen05.cp` are placed into SMEM during `OptimizeDotOperand`. So if WS sees tmem scales, it implies that `tcgen05.cp` cannot be used and MMA cannot be pipelined, and blindly replacing tmem scales with smem scales is incorrect.

This PR handles the case correctly by adding a `local_load` on the TMA buffer, whose result is then consumed by the tmem alloc. Since the scales are put into TMEM by a tmem store in this case, MMA cannot be pipelined. This is probably the first example where the load is async but the MMA is sync (not specialized): if we detect this pattern, the MMA op must have been put into the default partition. If we really want to support this pattern end to end, we would need to update the partition scheduling and the `pipelineMMA` code as well (probably not worth it).
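
For orientation, here is a rough before/after sketch of the scales path described above, distilled from the lit test added in this commit. It is schematic pseudo-IR rather than valid MLIR: types, tokens, and most operands are elided, and value names such as `%aref`, `%scales_desc`, and `%scales_r` are illustrative only.

    // Before aref insertion: scales are TMA-loaded and fed directly into a TMEM alloc.
    %scales    = tt.descriptor_load %scales_desc[...]
    %scales_tm = ttng.tmem_alloc %scales                      // #tmem_scales encoding
    %acc       = ttng.tc_gen5_mma_scaled ..., %scales_tm, ...

    // After aref insertion: the TMA load lands in the aref's SMEM buffer, the consumer
    // reads it back into registers with local_load, and the tmem_alloc (a tmem store
    // under the hood) keeps a register operand, so the MMA still sees TMEM scales.
    %buf       = nvws.aref.put.enter %aref ...
                 nvws.descriptor_load %scales_desc[...], %buf ...
                 nvws.aref.put.exit %aref ...
    %buf2      = nvws.aref.get.enter %aref ...
    %scales_r  = ttg.local_load %buf2 ...
                 nvws.aref.get.exit %aref ...
    %scales_tm = ttng.tmem_alloc %scales_r
    %acc       = ttng.tc_gen5_mma_scaled ..., %scales_tm, ...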
1 parent: 39f3c20

2 files changed: 129 additions & 53 deletions

test/NVWS/insert_aref.mlir

Lines changed: 41 additions & 0 deletions
@@ -2,10 +2,15 @@
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#shared3 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
+#shared4 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>
 #smem = #ttg.shared_memory
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+#tmem_scales = #ttng.tensor_memory_scales_encoding<>
 
 module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
 // FUNC-LABEL: @warp_specialize_tma_matmul
@@ -129,4 +134,40 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
 } {tt.num_stages = 2 : i32, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32]}
 tt.return
 }
+
+// CHECK-LABEL: @matmul_scaled_rhs_scales_tma
+tt.func @matmul_scaled_rhs_scales_tma(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared3>>, %arg4: !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared3>>, %arg5: !tt.tensordesc<tensor<128x8xi8, #shared2>>) {
+%true = arith.constant true
+%c0_i32 = arith.constant 0 : i32
+%c1_i32 = arith.constant 1 : i32
+%c64_i32 = arith.constant 64 : i32
+%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+%cst_0 = arith.constant dense<127> : tensor<128x8xi8, #linear>
+%result = ttng.tmem_alloc %cst_0 : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
+%0 = scf.for %arg6 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg7 = %cst) -> (tensor<128x128xf32, #blocked>) : i32 {
+%1 = arith.muli %arg6, %c64_i32 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i32
+%2 = tt.descriptor_load %arg3[%arg1, %1] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32} : !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared3>> -> tensor<128x64xf8E4M3FN, #blocked1>
+%3 = tt.descriptor_load %arg4[%arg2, %1] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32} : !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared3>> -> tensor<128x64xf8E4M3FN, #blocked1>
+%5 = ttg.local_alloc %2 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 2 : i32} : (tensor<128x64xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared3, #smem>
+%6 = ttg.local_alloc %3 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 2 : i32} : (tensor<128x64xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared3, #smem>
+
+// CHECK: nvws.aref.put.enter
+// CHECK: nvws.descriptor_load
+// CHECK: nvws.aref.put.exit
+%4 = tt.descriptor_load %arg5[%arg1, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = 2 : i32} : !tt.tensordesc<tensor<128x8xi8, #shared2>> -> tensor<128x8xi8, #linear>
+
+// CHECK: nvws.aref.get.enter
+// CHECK: [[REG:%.*]] = ttg.local_load
+// CHECK: nvws.aref.get.exit
+// CHECK: tmem_alloc [[REG]]
+%result_1 = ttng.tmem_alloc %4 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 2 : i32} : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
+
+%7 = ttg.memdesc_trans %6 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array<i32: 1, 0>, ttg.partition = 1 : i32} : !ttg.memdesc<128x64xf8E4M3FN, #shared3, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #shared4, #smem>
+%result_2, %token = ttng.tmem_alloc %arg7 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 0 : i32} : (tensor<128x128xf32, #blocked>) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+%8 = ttng.tc_gen5_mma_scaled %5, %7, %result_2[%token], %result, %result_1, %true, %true lhs = e4m3 rhs = e4m3 {loop.cluster = 0 : i32, loop.stage = 1 : i32, tt.self_latency = 1 : i32, ttg.partition = 1 : i32} : !ttg.memdesc<128x64xf8E4M3FN, #shared3, #smem>, !ttg.memdesc<64x128xf8E4M3FN, #shared4, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
+%result_3, %token_4 = ttng.tmem_load %result_2[%8] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = 0 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+scf.yield %result_3 : tensor<128x128xf32, #blocked>
+} {tt.num_stages = 2 : i64, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32]}
+tt.return
+}
 }

third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp

Lines changed: 88 additions & 53 deletions
@@ -49,19 +49,37 @@ SmallVector<ProducedValueInfo> getProducedValues(Operation *op, Block *loopBody,
   return producedValues;
 };
 
+template <typename AllocOp, typename LoadOp>
+std::optional<std::pair<AllocOp, LoadOp>> isLoadAndAlloc(Value result) {
+  auto alloc = result.getDefiningOp<AllocOp>();
+  if (!alloc)
+    return std::nullopt;
+  if (auto load = alloc.getSrc().template getDefiningOp<LoadOp>()) {
+    return std::make_pair(alloc, load);
+  }
+  return std::nullopt;
+}
+
+// if result is defined by descriptor_load followed by alloc, return the alloc
+// and the load ops as a pair.
+template <typename AllocOp> auto isDescLoadAndAlloc(Value result) {
+  return isLoadAndAlloc<AllocOp, triton::DescriptorOpInterface>(result);
+}
+
+template <typename AllocOp> auto isGlobalLoadAndAlloc(Value result) {
+  return isLoadAndAlloc<AllocOp, triton::LoadOp>(result);
+}
+
 ArefCreateOp createAref(OpBuilder &builder, ProducedValueInfo &producedValue) {
   auto result = producedValue.result;
-  MemDescType arefBufType;
 
-  if (auto memDescType = dyn_cast<MemDescType>(result.getType())) {
-    arefBufType = getMultiBufferedType(memDescType, 1);
-  } else if (auto tensorType = dyn_cast<RankedTensorType>(result.getType())) {
-    // if result is a value, create memdesctype for location where value will
-    // be stored
+  auto getSmemDescType = [](Value tensorResult) {
+    auto tensorType = cast<RankedTensorType>(tensorResult.getType());
     MemDescType memDescType;
     Attribute SharedMemorySpace =
         SharedMemorySpaceAttr::get(tensorType.getContext());
-    if (auto load = result.getDefiningOp<triton::DescriptorOpInterface>()) {
+    if (auto load =
+            tensorResult.getDefiningOp<triton::DescriptorOpInterface>()) {
      // A use of TMA which is not immediately consumed by LocalAlloc
      // This case applies, for example, when TMA is followed by SIMT ops
      // or MMAv2 is used.
@@ -73,15 +91,25 @@ ArefCreateOp createAref(OpBuilder &builder, ProducedValueInfo &producedValue) {
    } else {
      llvm_unreachable("Only TMA is expected for now.");
    }
-    arefBufType = getMultiBufferedType(memDescType, 1);
+    return memDescType;
+  };
+
+  MemDescType memDescType;
+  if (isDescLoadAndAlloc<LocalAllocOp>(result)) {
+    memDescType = dyn_cast<MemDescType>(result.getType());
+  } else if (auto opt = isDescLoadAndAlloc<TMEMAllocOp>(result)) {
+    auto descLoadResult = opt->first.getSrc();
+    memDescType = getSmemDescType(descLoadResult);
+  } else if (isa<RankedTensorType>(result.getType())) {
+    memDescType = getSmemDescType(result);
  } else {
-    std::string msg = "unsupported produced value type: " +
+    std::string msg = "createAref: unsupported produced value type: " +
                      mlir::debugString(result.getType());
    llvm::report_fatal_error(msg.c_str());
  }
 
-  assert(arefBufType &&
-         (isa<SharedMemorySpaceAttr>(arefBufType.getMemorySpace())));
+  MemDescType arefBufType = getMultiBufferedType(memDescType, 1);
+  assert(isa<SharedMemorySpaceAttr>(arefBufType.getMemorySpace()));
  auto loc = result.getLoc();
  auto alloc = triton::nvws::createAlloc(builder, loc, arefBufType, Value());
  return createArefCreateOp(builder, {arefBufType}, {alloc->getResult(0)}, loc);
@@ -127,26 +155,15 @@ void createNVWSDescriptorLoadOp(OpBuilder &builder, Operation *ttDescLoadOp,
  }
 }
 
-bool isDescLoadAndAlloc(Value result) {
-  auto alloc = result.getDefiningOp<LocalAllocOp>();
-  if (!alloc)
-    return false;
-  return alloc.getSrc().getDefiningOp<triton::DescriptorOpInterface>();
-}
-
-bool isGlobalLoadAndAlloc(Value result) {
-  auto alloc = result.getDefiningOp<LocalAllocOp>();
-  if (!alloc)
-    return false;
-  return alloc.getSrc().getDefiningOp<triton::LoadOp>();
-}
-
 StageCluster getStageClusterForProducer(Value producedValue) {
-  if (isDescLoadAndAlloc(producedValue) ||
-      isGlobalLoadAndAlloc(producedValue)) {
-    auto alloc = producedValue.getDefiningOp<LocalAllocOp>();
-    auto loadOp = alloc.getSrc().getDefiningOp();
-    return getStageCluster(loadOp);
+  if (auto opt = isDescLoadAndAlloc<LocalAllocOp>(producedValue)) {
+    return getStageCluster(opt->second);
+  } else if (auto opt = isDescLoadAndAlloc<TMEMAllocOp>(producedValue)) {
+    return getStageCluster(opt->second);
+  } else if (auto opt = isGlobalLoadAndAlloc<LocalAllocOp>(producedValue)) {
+    return getStageCluster(opt->second);
+  } else if (auto opt = isGlobalLoadAndAlloc<TMEMAllocOp>(producedValue)) {
+    return getStageCluster(opt->second);
  }
  return getStageCluster(producedValue.getDefiningOp());
 }
@@ -173,15 +190,21 @@ SmallVector<Operation *> createArefPut(PartitionBuilder &builder,
 
  auto producerKind = AsyncOp::NONE;
  SmallVector<Operation *> staleOps;
-  if (isDescLoadAndAlloc(result)) {
-    auto alloc = result.getDefiningOp<LocalAllocOp>();
-    auto descOp = alloc.getSrc().getDefiningOp();
+  if (auto opt = isDescLoadAndAlloc<LocalAllocOp>(result)) {
+    auto [alloc, descOp] = *opt;
    createNVWSDescriptorLoadOp(builder, descOp, dataBuf, producerPartition,
                               schedule, loc);
    producerKind = AsyncOp::TMALoad;
    staleOps.push_back(alloc);
    staleOps.push_back(descOp);
-  } else if (isGlobalLoadAndAlloc(result)) {
+  } else if (auto opt = isDescLoadAndAlloc<TMEMAllocOp>(result)) {
+    auto descOp = opt->second;
+    createNVWSDescriptorLoadOp(builder, descOp, dataBuf, producerPartition,
+                               schedule, loc);
+    producerKind = AsyncOp::TMALoad;
+    staleOps.push_back(descOp);
+  } else if (isGlobalLoadAndAlloc<LocalAllocOp>(result) ||
+             isGlobalLoadAndAlloc<TMEMAllocOp>(result)) {
    llvm_unreachable("cpasync not supported yet");
  } else if (auto tensorType = dyn_cast<RankedTensorType>(result.getType())) {
    if (auto descOp = result.getDefiningOp<triton::DescriptorOpInterface>()) {
@@ -197,7 +220,7 @@ SmallVector<Operation *> createArefPut(PartitionBuilder &builder,
      llvm_unreachable("Aref for values not supported yet");
    }
  } else {
-    std::string msg = "unsupported produced value type: " +
+    std::string msg = "createArefPut: unsupported produced value type: " +
                      mlir::debugString(result.getType());
    llvm::report_fatal_error(msg.c_str());
  }
@@ -327,26 +350,34 @@ void createArefGet(PartitionBuilder &builder, scf::ForOp loop,
  Value token = getEnterOp.getToken();
 
  Operation *exitInsertPointAfter = nullptr;
+
+  auto replaceUsesWithLocalLoad = [&](Value result, StageCluster stageCluster) {
+    auto localLoadOp = builder.createInto<LocalLoadOp>(
+        *consumerPartition, stageCluster, result.getType(), dataBuf);
+    result.replaceAllUsesWith(localLoadOp.getResult());
+    schedule.insert(consumerPartition, localLoadOp);
+    if (consumers.size() == 1) {
+      // If there is only one consumer and we hit this code path, the empty
+      // barrier can be released after local load.
+      exitInsertPointAfter = localLoadOp;
+    }
+  };
+
  for (auto result : results) {
-    if (auto memDescType = dyn_cast<MemDescType>(result.getType())) {
+    if (auto localAlloc = result.getDefiningOp<LocalAllocOp>()) {
+      auto memDescType = cast<MemDescType>(result.getType());
      auto callback = [&](Operation *oldOp, Operation *newOp) {
        assert(schedule.getPartition(oldOp) == consumerPartition);
        schedule.insert(consumerPartition, newOp);
      };
-      replaceUsesAndPropagateType(builder, result.getDefiningOp(), dataBuf,
-                                  callback);
-    } else if (auto tensorType = dyn_cast<RankedTensorType>(result.getType())) {
-      auto localLoadOp = builder.createInto<LocalLoadOp>(
-          *consumerPartition, stageClusterEnter, tensorType, dataBuf);
-      result.replaceAllUsesWith(localLoadOp.getResult());
-      schedule.insert(consumerPartition, localLoadOp);
-      if (consumers.size() == 1) {
-        // If there is only one consumer and we hit this code path, the empty
-        // barrier can be released after local load.
-        exitInsertPointAfter = localLoadOp;
-      }
+      replaceUsesAndPropagateType(builder, localAlloc, dataBuf, callback);
+    } else if (auto tmemAlloc = result.getDefiningOp<TMEMAllocOp>()) {
+      builder.setInsertionPoint(tmemAlloc);
+      replaceUsesWithLocalLoad(tmemAlloc.getSrc(), stageClusterEnter);
+    } else if (isa<RankedTensorType>(result.getType())) {
+      replaceUsesWithLocalLoad(result, stageClusterEnter);
    } else {
-      std::string msg = "unsupported produced value type: " +
+      std::string msg = "createArefGet: unsupported produced value type: " +
                        mlir::debugString(result.getType());
      llvm::report_fatal_error(msg.c_str());
    }
@@ -384,9 +415,12 @@ bool insertArefs(PartitionBuilder &builder, scf::ForOp loop,
 
  processResultUses(producedValue.result);
 
-  if (isDescLoadAndAlloc(producedValue.result)) {
+  if (auto opt = isDescLoadAndAlloc<LocalAllocOp>(producedValue.result)) {
    // Process the register use as well
-    auto alloc = producedValue.result.getDefiningOp<LocalAllocOp>();
+    auto alloc = opt->first;
+    processResultUses(alloc.getSrc());
+  } else if (auto opt = isDescLoadAndAlloc<TMEMAllocOp>(producedValue.result)) {
+    auto alloc = opt->first;
    processResultUses(alloc.getSrc());
  }
 
@@ -446,7 +480,8 @@ class NVWSArefInsertion
        return WalkResult::advance();
      }
      // Only handles load ops for now.
-      if (isDescLoadAndAlloc(op->getResult(0)) ||
+      if (isDescLoadAndAlloc<LocalAllocOp>(op->getResult(0)) ||
+          isDescLoadAndAlloc<TMEMAllocOp>(op->getResult(0)) ||
          (allowDescLoadRegUse &&
           (isa<triton::DescriptorOpInterface>(op)))) {
        ops.push_back(op);
@@ -459,7 +494,7 @@ class NVWSArefInsertion
        getProducedValues(op, loop.getBody(), *schedule);
    for (auto producedValue : producedValues) {
      PartitionBuilder builder(op->getLoc(), op);
-      builder.setInsertionPointAfter(op);
+      builder.setInsertionPoint(op);
      if (insertArefs(builder, loop, *schedule, producedValue, arefTag))
        arefTag++;
    }
