Commit 7f5ebe7

[Blackwell] Support optional scale TMAs in warp specialization for tl.dot_scaled (triton-lang#6551)
This enables automatic warp specialization for block scaled workloads.
1 parent 981e987 commit 7f5ebe7
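
For context, a "block scaled workload" here means a kernel that calls tl.dot_scaled, where the low-precision operands carry per-block scale factors and either scale operand is optional. The sketch below is illustrative only and is not taken from this commit: it assumes the documented tl.dot_scaled signature (lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=...), e4m3 operands with e8m0 scales stored as uint8 (one scale per 32 elements along K), and dimensions divisible by the block sizes; the kernel and pointer names are hypothetical. The commit itself targets the case where these loads are issued through TMA tensor descriptors (the new test uses tt.descriptor_load); the sketch uses plain pointers to stay short.

    # Illustrative sketch only (not from this commit): a block-scaled matmul
    # in the Triton DSL. Assumes fp8 (e4m3) operands, e8m0 scales stored as
    # uint8 with one scale per 32 K elements, and M, N, K divisible by the
    # block sizes. All names are hypothetical.
    import triton
    import triton.language as tl


    @triton.jit
    def mxfp8_matmul_kernel(a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, c_ptr,
                            M, N, K,
                            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                            BLOCK_K: tl.constexpr):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        offs_sk = tl.arange(0, BLOCK_K // 32)  # one e8m0 scale per 32 K elements

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, K, BLOCK_K):
            # fp8 operand tiles: row-major A [M, K] and B [K, N].
            a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
            b = tl.load(b_ptr + (k + offs_k)[:, None] * N + offs_n[None, :])
            # Per-block scales: A scales are [M, K // 32], B scales are [N, K // 32].
            a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (K // 32) +
                              (k // 32 + offs_sk)[None, :])
            b_scale = tl.load(b_scale_ptr + offs_n[:, None] * (K // 32) +
                              (k // 32 + offs_sk)[None, :])
            # Either scale argument may also be None or come from elsewhere;
            # when the scale loads that do exist go through TMA descriptors,
            # they are the "optional scale TMAs" this commit handles during
            # warp specialization on Blackwell.
            acc = tl.dot_scaled(a, a_scale, "e4m3", b, b_scale, "e4m3", acc=acc)

        c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
        tl.store(c_ptrs, acc)

At the IR level the two scales surface as the a_scale and b_scale operands of ttng.tc_gen5_mma_scaled. The new test below feeds the B scale through a TMA descriptor load while the A scale comes from a constant placed in tensor memory, which is exactly the optional, per-operand situation the warp specialization pass now handles.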

File tree

2 files changed: 129 additions, 25 deletions

  lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp
  test/TritonGPU/load-mma-specialization.mlir

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 74 additions & 25 deletions

@@ -207,6 +207,17 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
                              "loads for `tt.dot` operands");
   }
 
+  SmallVector<Operation *> aScaleChain, bScaleChain;
+  auto scaledMMAOp = dyn_cast<ttng::TCGen5MMAScaledOp>(mmaOp.getOperation());
+  if (scaledMMAOp) {
+    if (failed(
+            findSingleChainToLoad(loop, scaledMMAOp.getAScale(), aScaleChain)))
+      aScaleChain.clear();
+    if (failed(
+            findSingleChainToLoad(loop, scaledMMAOp.getBScale(), bScaleChain)))
+      bScaleChain.clear();
+  }
+
   ttng::TMEMAllocOp oldAccAlloc =
       mmaOp.getAccumulator().getDefiningOp<ttng::TMEMAllocOp>();
   if (!oldAccAlloc)
@@ -218,7 +229,9 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
 
   // Determine if the MMA accumulator can be multibuffered.
   auto isLoadPipelineable = [&](Operation *op) {
-    return llvm::is_contained({aChain.back(), bChain.back()}, op);
+    return llvm::is_contained(llvm::to_vector(llvm::concat<Operation *>(
+                                  aChain, bChain, aScaleChain, bScaleChain)),
+                              op);
   };
   bool accIsMultiBuffered =
       // All operand feeds are pipelineable.
@@ -280,16 +293,27 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
   Partition *mmaPartition = schedule.addPartition(numStages);
 
   // Multi-buffer the loads.
-  auto [loadIndex, loadPhase] = addIndexAndPhase(b, loop, numStages);
+  BlockArgument loadIndex;
+  BlockArgument loadPhase;
+  std::tie(loadIndex, loadPhase) = addIndexAndPhase(b, loop, numStages);
+
+  auto allocate = [&](const SmallVector<Operation *> &chain)
+      -> std::tuple<Operation *, RankedTensorType, SharedEncodingTrait, Value> {
+    if (chain.empty())
+      return {nullptr, RankedTensorType(), SharedEncodingTrait(), Value()};
+
+    Operation *load = chain.back();
+    auto type = cast<RankedTensorType>(load->getResult(0).getType());
+    SharedEncodingTrait enc = getSharedEncoding(chain.back());
+    Value alloc = createAlloc(loop, type, load->getLoc(), enc, numStages);
+
+    return {load, type, enc, alloc};
+  };
 
-  Operation *aLoad = aChain.back();
-  Operation *bLoad = bChain.back();
-  auto aType = cast<RankedTensorType>(aLoad->getResult(0).getType());
-  auto bType = cast<RankedTensorType>(bLoad->getResult(0).getType());
-  SharedEncodingTrait aEnc = getSharedEncoding(aChain.back());
-  SharedEncodingTrait bEnc = getSharedEncoding(bChain.back());
-  Value aAlloc = createAlloc(loop, aType, aLoad->getLoc(), aEnc, numStages);
-  Value bAlloc = createAlloc(loop, bType, bLoad->getLoc(), bEnc, numStages);
+  auto [aLoad, aType, aEnc, aAlloc] = allocate(aChain);
+  auto [bLoad, bType, bEnc, bAlloc] = allocate(bChain);
+  auto [aScaleLoad, aScaleType, aScaleEnc, aScaleAlloc] = allocate(aScaleChain);
+  auto [bScaleLoad, bScaleType, bScaleEnc, bScaleAlloc] = allocate(bScaleChain);
 
   // Share the same set of barriers for both.
   Value emptyBars = createBarrierAlloc(loop, numStages);
@@ -304,9 +328,23 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
   int loadSizeInBytes =
       product(aType.getShape()) * aType.getElementTypeBitWidth() / 8 +
       product(bType.getShape()) * bType.getElementTypeBitWidth() / 8;
+  if (aScaleLoad)
+    loadSizeInBytes += product(aScaleType.getShape()) *
+                       aScaleType.getElementTypeBitWidth() / 8;
+  if (bScaleLoad)
+    loadSizeInBytes += product(bScaleType.getShape()) *
+                       bScaleType.getElementTypeBitWidth() / 8;
 
   // Insert before the group of loads.
-  b.setInsertionPoint(aLoad->isBeforeInBlock(bLoad) ? aLoad : bLoad);
+  SmallVector<Operation *> allLoads{aLoad, bLoad};
+  if (aScaleLoad)
+    allLoads.push_back(aScaleLoad);
+  if (bScaleLoad)
+    allLoads.push_back(bScaleLoad);
+  std::sort(allLoads.begin(), allLoads.end(),
+            [](Operation *a, Operation *b) { return a->isBeforeInBlock(b); });
+  b.setInsertionPoint(allLoads.front());
+
   // Wait for the buffer to be empty and the corresponding barrier to be
   // exhausted.
   Value curEmptyBar = createSingleBufferView(b, emptyBars, loadIndex);
@@ -318,19 +356,21 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
                        loadSizeInBytes, intCst(true, 1));
 
   // Replace the loads with async copies.
-  b.setInsertionPoint(aLoad);
-  Value aView = createSingleBufferView(b, aAlloc, loadIndex);
-  lowerTMACopy(b, *loadPartition, aLoad, curLoadBar, aView);
-  replaceUsesAndPropagateType(b, *aLoad->user_begin(), aView);
-  aLoad->user_begin()->erase();
-  aLoad->erase();
-
-  b.setInsertionPoint(bLoad);
-  Value bView = createSingleBufferView(b, bAlloc, loadIndex);
-  lowerTMACopy(b, *loadPartition, bLoad, curLoadBar, bView);
-  replaceUsesAndPropagateType(b, *bLoad->user_begin(), bView);
-  bLoad->user_begin()->erase();
-  bLoad->erase();
+  auto lowerLoadAndPropagate = [&](Operation *load, Value alloc,
+                                   Value barrier) {
+    b.setInsertionPoint(load);
+    Value view = createSingleBufferView(b, alloc, loadIndex);
+    lowerTMACopy(b, *loadPartition, load, barrier, view);
+    replaceUsesAndPropagateType(b, *load->user_begin(), view);
+    load->user_begin()->erase();
+    load->erase();
+  };
+  lowerLoadAndPropagate(aLoad, aAlloc, curLoadBar);
+  lowerLoadAndPropagate(bLoad, bAlloc, curLoadBar);
+  if (aScaleLoad)
+    lowerLoadAndPropagate(aScaleLoad, aScaleAlloc, curLoadBar);
+  if (bScaleLoad)
+    lowerLoadAndPropagate(bScaleLoad, bScaleAlloc, curLoadBar);
 
   // Place the remaining users in the MMA partition. Re-acquire the use chain
   // because some ops were invalidated by `replaceUsesAndPropagateType`.
@@ -339,9 +379,18 @@ LogicalResult triton::gpu::specializeLoadMMADependencies(scf::ForOp &loop,
   aChain.push_back(mmaOp);
   (void)findSingleChainToLoad(loop, dot.getA(), aChain);
   (void)findSingleChainToLoad(loop, dot.getB(), bChain);
+  if (aScaleLoad) {
+    aScaleChain.clear();
+    (void)findSingleChainToLoad(loop, scaledMMAOp.getAScale(), aScaleChain);
+  }
+  if (bScaleLoad) {
+    bScaleChain.clear();
+    (void)findSingleChainToLoad(loop, scaledMMAOp.getBScale(), bScaleChain);
+  }
 
   // Place users in the MMA partition.
-  auto allUsers = llvm::to_vector(llvm::concat<Operation *>(aChain, bChain));
+  auto allUsers = llvm::to_vector(
+      llvm::concat<Operation *>(aChain, bChain, aScaleChain, bScaleChain));
   for (Operation *user : allUsers)
     mmaPartition->insert(user);
 
test/TritonGPU/load-mma-specialization.mlir

Lines changed: 55 additions & 0 deletions

@@ -762,4 +762,59 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_no_multibuf_flag(
   tt.return
 }
 
+
+
+tt.func @matmul_scaled_rhs_scales_tma(
+  // CHECK-LABEL: @matmul_scaled_rhs_scales_tma
+  %k_tiles: i32,
+  %off_m: i32,
+  %off_n: i32,
+  %a_desc: !tt.tensordesc<tensor<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>>>,
+  %b_desc: !tt.tensordesc<tensor<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>>>,
+  %b_scale_desc: !tt.tensordesc<tensor<128x8xi8, #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}>>>
+) {
+  %true = arith.constant true
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %BLOCK_K = arith.constant 64 : i32
+  %zero = arith.constant dense<0.0> : tensor<128x128xf32, #acc_layout>
+
+  %a_scales_const = arith.constant dense<127> : tensor<128x8xi8, #oper_layout>
+  %a_scales_tmem = ttng.tmem_alloc %a_scales_const : (tensor<128x8xi8, #oper_layout>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
+
+  %result = scf.for %k = %c0_i32 to %k_tiles step %c1_i32 iter_args(%acc = %zero) -> tensor<128x128xf32, #acc_layout> : i32 {
+    %off_k = arith.muli %k, %BLOCK_K : i32
+
+    // CHECK: %{{[0-9]+}} = ttg.memdesc_subview %{{[0-9]+}}[%arg7, %c0_i32, %c0_i32]
+    // CHECK-NEXT: %{{[0-9]+}} = ttng.tensor_desc_to_tma_ptr %arg3 {ttg.partition = 0 : i32}
+    // CHECK-NEXT: ttng.async_tma_copy_global_to_local %{{[0-9]+}}[%arg1, %{{[0-9]+}}] %{{[0-9]+}}, %{{[0-9]+}}, %true {ttg.partition = 0 : i32}
+    %a_reg = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc<tensor<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>>> -> tensor<128x64xf8E4M3FN, #oper_layout>
+
+    // CHECK-NEXT: %{{[0-9]+}} = ttg.memdesc_subview %{{[0-9]+}}[%arg7, %c0_i32, %c0_i32]
+    // CHECK-NEXT: %{{[0-9]+}} = ttng.tensor_desc_to_tma_ptr %arg4 {ttg.partition = 0 : i32}
+    // CHECK-NEXT: ttng.async_tma_copy_global_to_local %{{[0-9]+}}[%arg2, %{{[0-9]+}}] %{{[0-9]+}}, %{{[0-9]+}}, %true {ttg.partition = 0 : i32}
+    %b_reg = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc<tensor<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>>> -> tensor<128x64xf8E4M3FN, #oper_layout>
+
+    // CHECK-NEXT: %{{[0-9]+}} = ttg.memdesc_subview %{{[0-9]+}}[%arg7, %c0_i32, %c0_i32]
+    // CHECK-NEXT: %{{[0-9]+}} = ttng.tensor_desc_to_tma_ptr %arg5 {ttg.partition = 0 : i32}
+    // CHECK-NEXT: ttng.async_tma_copy_global_to_local %{{[0-9]+}}[%arg1, %c0_i32] %{{[0-9]+}}, %{{[0-9]+}}, %true {ttg.partition = 0 : i32}
+    %b_scales_reg = tt.descriptor_load %b_scale_desc[%off_m, %c0_i32] : !tt.tensordesc<tensor<128x8xi8, #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}>>> -> tensor<128x8xi8, #oper_layout>
+
+    %a_sh = ttg.local_alloc %a_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>, #smem>
+    %b_sh_raw = ttg.local_alloc %b_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>, #smem>
+    %b_sh = ttg.memdesc_trans %b_sh_raw {order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>, #smem>
+
+    %b_scales_tmem = ttng.tmem_alloc %b_scales_reg : (tensor<128x8xi8, #oper_layout>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
+
+    %c_tmem = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>
+
+    ttng.tc_gen5_mma_scaled %a_sh, %b_sh, %c_tmem, %a_scales_tmem, %b_scales_tmem, %true, %true lhs = e4m3 rhs = e4m3 : !ttg.memdesc<128x64xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>, #smem>, !ttg.memdesc<64x128xf8E4M3FN, #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
+
+    %c = ttng.tmem_load %c_tmem : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>
+    scf.yield %c : tensor<128x128xf32, #acc_layout>
+  } {tt.warp_specialize}
+
+  tt.return
+}
+
 }
