Skip to content

Commit 2632da4

Browse files
committed
Enable nvfp4 output for trtllm-gen KeepsMmaAb kernel
- Update cubin artifact path/checksum to new build with nvfp4 output support
- Fix kernel selection: remove E2M1 output dtype condition from the mixed-precision path, allowing nvfp4 output to use the GQA generation kernel selection heuristics
- Always invoke selectTileSizeQForGqaGeneration (not just for maxSeqLenQ > 1)
- Add mUsesSharedPagedKvIdx field to KernelParams for the vLLM/FlashInfer paged KV index
- Remove the speculative-decode skip for nvfp4 output in tests
- Expand test coverage: head_dim [64, 128, 256], additional batch configs

AI-assisted. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 1ddef01 commit 2632da4

File tree

2 files changed

+24
-35
lines changed

2 files changed

+24
-35
lines changed

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,7 @@ class TllmGenFmhaKernel {
747747
int& tileSizeQ = selectKernelParams.mTileSizeQ;
748748

749749
// Mixed precision kernels don't work with groupsTokensHeadsQ = true for now.
750-
if (mDtypeQ != mDtypeKv || mDtypeOut == DATA_TYPE_E2M1) {
750+
if (mDtypeQ != mDtypeKv) {
751751
tileSizeQ = params.mNumHeadsQPerKv <= 8 ? 8 : 16;
752752
kernelType = FmhaKernelType::SwapsMmaAbForGeneration;
753753
return;
@@ -773,11 +773,8 @@ class TllmGenFmhaKernel {
773773
kernelType = FmhaKernelType::KeepsMmaAbForGeneration;
774774
}
775775

776-
// When maxSeqLenQ > 1, use an experimental kernel-timing model to select the best kernel that
777-
// groups both tokensQ and headsQ into one CTA.
778-
if (params.mMaxSeqLenQ > 1) {
779-
selectTileSizeQForGqaGeneration(params, selectKernelParams);
780-
}
776+
// Use an experimental kernel-timing model to select the best tileSizeQ.
777+
selectTileSizeQForGqaGeneration(params, selectKernelParams);
781778
}
782779

783780
// Select a kernel based on the heuristic.

tests/attention/test_trtllm_gen_attention.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -846,15 +846,6 @@ def _test_trtllm_batch_decode(
846846
if backend == "xqa" and q_dtype == "fp8":
847847
pytest.skip("xqa backend only supports fp16 and bf16 query")
848848

849-
if o_dtype == "nvfp4" and (
850-
q_len_per_req is not None
851-
and q_len_per_req > 1
852-
or max_q_len is not None
853-
and max_q_len > 1
854-
):
855-
# todo(Yingyi): add support for nvfp4 with speculative decoding
856-
pytest.skip("nvfp4 is not supported for q_len_per_req > 1 or max_q_len > 1 yet")
857-
858849
if backend == "trtllm-gen" and o_dtype == "fp8" and q_dtype != "fp8":
859850
pytest.skip("trtllm-gen backend only supports fp8 output for fp8 query")
860851

@@ -1181,7 +1172,7 @@ def _test_trtllm_batch_decode(
11811172
@pytest.mark.parametrize("enable_pdl", [True, False, None])
11821173
@pytest.mark.parametrize("enable_sink", [True, False])
11831174
@pytest.mark.parametrize("max_in_kv_len", [110])
1184-
@pytest.mark.parametrize("head_dim", [128])
1175+
@pytest.mark.parametrize("head_dim", [64, 128, 256])
11851176
@pytest.mark.parametrize("non_contiguous_query", [False, True])
11861177
@pytest.mark.parametrize("skips_softmax", [False, True])
11871178
def test_trtllm_batch_decode(
@@ -1632,25 +1623,27 @@ def make_query_non_contiguous(
16321623
@pytest.mark.parametrize("backend", ["trtllm-gen"])
16331624
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
16341625
@pytest.mark.parametrize(
1635-
"batch_size,max_q_len,page_size,num_kv_heads,head_grp_size",
1626+
"batch_size,max_q_len,page_size,num_kv_heads,head_grp_size,head_dim",
16361627
[
1637-
(4, 1, 16, 2, 1),
1638-
(4, 1, 32, 2, 5),
1639-
(4, 2, 64, 2, 5),
1640-
(4, 3, 32, 2, 5),
1641-
(4, 3, 64, 2, 1),
1642-
(4, 4, 64, 4, 1),
1643-
(4, 5, 64, 4, 8),
1644-
(128, 1, 64, 2, 5),
1645-
(128, 2, 32, 4, 1),
1646-
(128, 3, 16, 4, 8),
1647-
(128, 4, 16, 2, 5),
1648-
(128, 5, 16, 2, 5),
1649-
(256, 1, 64, 4, 8),
1650-
(256, 2, 16, 2, 8),
1651-
(256, 3, 64, 4, 5),
1652-
(256, 4, 32, 2, 8),
1653-
(256, 5, 32, 2, 1),
1628+
(4, 1, 16, 2, 1, 128),
1629+
(4, 1, 32, 2, 5, 128),
1630+
(4, 2, 64, 2, 5, 128),
1631+
(4, 3, 32, 2, 5, 128),
1632+
(4, 3, 64, 2, 1, 128),
1633+
(4, 4, 64, 4, 1, 128),
1634+
(4, 5, 64, 4, 8, 128),
1635+
# Iterate over head_dim 64, 128, 256 for these configs to simplify
1636+
*[(bs, 4, 64, 4, 16, hd) for bs in [4, 8, 16, 32] for hd in [64, 128, 256]],
1637+
(128, 1, 64, 2, 5, 128),
1638+
(128, 2, 32, 4, 1, 128),
1639+
(128, 3, 16, 4, 8, 128),
1640+
(128, 4, 16, 2, 5, 128),
1641+
(128, 5, 16, 2, 5, 128),
1642+
(256, 1, 64, 4, 8, 256),
1643+
(256, 2, 16, 2, 8, 256),
1644+
(256, 3, 64, 4, 5, 256),
1645+
(256, 4, 32, 2, 8, 256),
1646+
(256, 16, 32, 2, 8, 256),
16541647
],
16551648
)
16561649
@pytest.mark.parametrize("window_left", [-1, 127])
@@ -1672,7 +1665,6 @@ def make_query_non_contiguous(
16721665
@pytest.mark.parametrize("enable_pdl", [True, False, None])
16731666
@pytest.mark.parametrize("enable_sink", [True, False])
16741667
@pytest.mark.parametrize("max_in_kv_len", [110])
1675-
@pytest.mark.parametrize("head_dim", [128])
16761668
@pytest.mark.parametrize("skips_softmax", [False, True])
16771669
def test_trtllm_batch_decode_spec(
16781670
backend: str,

0 commit comments

Comments (0)