Commit 8454f55
[TRTLLM-11289][fix] Fix MMA accumulation bug in BF16 dense GEMM kernel
When mma_inst_tile_k > 1, cute.gemm() generates multiple sub-MMA
instructions that all share the same ACCUMULATE flag. With
ACCUMULATE=False on the first K tile, every sub-MMA cleared the
accumulator, so only the last sub-MMA's result survived and the
products of the first (mma_inst_tile_k - 1) * mma_inst_shape_k
K elements of that tile were dropped from each output tile.
This caused GSM8K accuracy to drop from 64.7% to 28.5%.
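
The failure mode can be reproduced with a toy model in plain Python. This is only a sketch: `sub_mma` and `buggy_k_tile` are hypothetical stand-ins for a single MMA instruction and for a gemm call that expands into several sub-MMAs, not the CuTe DSL API, and the tile sizes are illustrative values rather than the kernel's real configuration.

```python
# Toy model of one output element. Names and sizes are hypothetical,
# chosen only to illustrate the shared-flag behaviour described above.

def sub_mma(acc, a_blk, b_blk, accumulate):
    """One MMA instruction: acc = (acc if accumulate else 0) + dot(a_blk, b_blk)."""
    base = acc if accumulate else 0.0
    return base + sum(x * y for x, y in zip(a_blk, b_blk))


def buggy_k_tile(acc, a_tile, b_tile, accumulate, inst_k):
    """A gemm call over one K tile whose sub-MMAs all share the same flag."""
    for k in range(0, len(a_tile), inst_k):
        # With accumulate=False, every sub-MMA overwrites the previous result.
        acc = sub_mma(acc, a_tile[k:k + inst_k], b_tile[k:k + inst_k], accumulate)
    return acc


mma_inst_shape_k = 4   # K elements consumed by one sub-MMA (assumed value)
mma_inst_tile_k = 4    # sub-MMAs per K tile (assumed value)
tile_k = mma_inst_tile_k * mma_inst_shape_k

# All-ones operands make each product 1, so the sum counts surviving elements.
a = [1.0] * tile_k
b = [1.0] * tile_k

expected = sum(x * y for x, y in zip(a, b))
got = buggy_k_tile(0.0, a, b, accumulate=False, inst_k=mma_inst_shape_k)
lost = expected - got

print(f"expected={expected}, got={got}, lost={lost}")
```

In this model only the last block of `mma_inst_shape_k` products survives the first K tile, and `lost` equals `(mma_inst_tile_k - 1) * mma_inst_shape_k`.
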
Fix by adding an inner kblock loop that iterates sub-MMA instructions
individually and sets ACCUMULATE=True after the first cute.gemm() call,
matching the pattern used by blockscaled_contiguous_grouped_gemm.py.
GSM8K accuracy restored to 64.86% (reference: 64.74%).
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
1 parent e36f1ee, commit 8454f55
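
The fix pattern described in the commit message (an inner kblock loop, with ACCUMULATE switched to True after the first call) can be sketched in the same toy style. Again, this is an illustration under assumptions: `sub_mma` and `fixed_gemm` are hypothetical helpers in plain Python, not the kernel's code or the CuTe DSL API, and the sizes are made up.

```python
# Toy model of the fixed loop structure for one output element.
# All names and sizes here are hypothetical.

def sub_mma(acc, a_blk, b_blk, accumulate):
    """One MMA instruction: acc = (acc if accumulate else 0) + dot(a_blk, b_blk)."""
    base = acc if accumulate else 0.0
    return base + sum(x * y for x, y in zip(a_blk, b_blk))


def fixed_gemm(a, b, tile_k, inst_k):
    """Issue sub-MMAs one at a time; only the very first one clears the accumulator."""
    acc = 123.0          # stale accumulator contents that must not leak into the result
    accumulate = False   # first sub-MMA of the first K tile overwrites acc
    for k_tile in range(0, len(a), tile_k):
        for kblock in range(0, tile_k, inst_k):
            lo = k_tile + kblock
            acc = sub_mma(acc, a[lo:lo + inst_k], b[lo:lo + inst_k], accumulate)
            accumulate = True  # every later sub-MMA adds to the running sum
    return acc


a = [float(i) for i in range(32)]
b = [2.0] * 32

result = fixed_gemm(a, b, tile_k=16, inst_k=4)
reference = sum(x * y for x, y in zip(a, b))

print(f"result={result}, reference={reference}")
```

Starting the accumulator at a nonzero value checks both halves of the pattern: the first sub-MMA still clears stale contents, and every later sub-MMA contributes to the sum.
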
File tree
1 file changed, 16 additions and 4 deletions
tensorrt_llm/_torch/cute_dsl_kernels/blackwell
Diff body not preserved. One hunk: original lines 598-612 become lines 598-624, with 3 lines added after line 600 and original lines 606-609 replaced by 13 new lines.