Commit 8454f55

[TRTLLM-11289][fix] Fix MMA accumulation bug in BF16 dense GEMM kernel
When mma_inst_tile_k > 1, cute.gemm() generates multiple sub-MMA instructions that all share the same ACCUMULATE flag. With ACCUMULATE=False on the first K tile, every sub-MMA cleared the accumulator so only the last sub-MMA's result survived, losing (mma_inst_tile_k - 1) * mma_inst_shape_k elements per output tile. This caused GSM8K accuracy to drop from 64.7% to 28.5%.

Fix by adding an inner kblock loop that iterates sub-MMA instructions individually and sets ACCUMULATE=True after the first cute.gemm() call, matching the pattern used by blockscaled_contiguous_grouped_gemm.py. GSM8K accuracy restored to 64.86% (reference: 64.74%).

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
1 parent e36f1ee commit 8454f55
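
The failure mode described in the commit message can be reproduced with a toy model in plain Python. This is only a sketch under stated assumptions: it does not use the CuTe DSL, and `sub_mma`, `split`, `dot_buggy`, `dot_fixed`, and the sizes `TILE_K` and `INST_K` are hypothetical names chosen for illustration. Each "sub-MMA" is modeled as a dot product over `INST_K` elements that either adds to or overwrites a scalar accumulator, depending on the flag.

```python
# Toy model of the ACCUMULATE flag (hypothetical helpers, not the CuTe DSL API).

def sub_mma(acc, a_blk, b_blk, accumulate):
    """One sub-MMA: overwrite acc when accumulate is False, add to it otherwise."""
    partial = sum(x * y for x, y in zip(a_blk, b_blk))
    return acc + partial if accumulate else partial


def split(vec, size):
    """Split a list into consecutive chunks of `size` elements."""
    return [vec[i:i + size] for i in range(0, len(vec), size)]


def dot_buggy(a, b, tile_k, inst_k):
    """Old pattern: the flag is chosen once per K tile and shared by every sub-MMA."""
    acc = 0
    for k_tile, (a_tile, b_tile) in enumerate(zip(split(a, tile_k), split(b, tile_k))):
        accumulate = k_tile != 0
        for a_blk, b_blk in zip(split(a_tile, inst_k), split(b_tile, inst_k)):
            # On the first K tile every sub-MMA overwrites the accumulator.
            acc = sub_mma(acc, a_blk, b_blk, accumulate)
    return acc


def dot_fixed(a, b, tile_k, inst_k):
    """New pattern: reset once per output tile, then accumulate after the first sub-MMA."""
    acc = 0
    accumulate = False
    for a_tile, b_tile in zip(split(a, tile_k), split(b, tile_k)):
        for a_blk, b_blk in zip(split(a_tile, inst_k), split(b_tile, inst_k)):
            acc = sub_mma(acc, a_blk, b_blk, accumulate)
            accumulate = True
    return acc


TILE_K = 8   # K elements per K tile (assumed value)
INST_K = 2   # K elements per sub-MMA, so 4 sub-MMAs per K tile (assumed value)
a = [1] * 16
b = [1] * 16

reference = sum(x * y for x, y in zip(a, b))
print(reference)                           # 16
print(dot_buggy(a, b, TILE_K, INST_K))     # 10
print(dot_fixed(a, b, TILE_K, INST_K))     # 16
```

With four sub-MMAs per K tile and two elements per sub-MMA, the buggy variant drops (4 - 1) * 2 = 6 elements from the first K tile, which matches the (mma_inst_tile_k - 1) * mma_inst_shape_k loss stated in the commit message.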

1 file changed: 16 additions, 4 deletions

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py

@@ -598,15 +598,27 @@ class SharedStorage:
 if is_leader_cta:
     acc_pipeline.producer_acquire(acc_producer_state)
 
+# Reset ACCUMULATE for each new output tile
+tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
+
 for k_tile in range(k_tile_cnt):
     if is_leader_cta:
         handle = ab_consumer.wait_and_advance(
             peek_ab_full_status)
 
-        tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0)
-        tile_crd = (None, None, None, handle.index)
-        cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd],
-                  tCrB[tile_crd], tCtAcc)
+        # Inner loop over kblocks within each K tile.
+        # Set ACCUMULATE=True after first gemm call to
+        # avoid clearing the accumulator on each sub-MMA.
+        num_kblocks = cute.size(tCrA, mode=[2])
+        for kblock_idx in cutlass.range(
+                num_kblocks, unroll_full=True):
+            kblock_crd = (None, None, kblock_idx,
+                          handle.index)
+            cute.gemm(tiled_mma, tCtAcc,
+                      tCrA[kblock_crd],
+                      tCrB[kblock_crd], tCtAcc)
+            tiled_mma.set(
+                tcgen05.Field.ACCUMULATE, True)
 
         handle.release()
 