Skip to content

Commit e9f43f0

Browse files
authored
bugfix: fix trtllm-gen mla error on new interface (#1348)
<!-- .github/pull_request_template.md --> ## πŸ“Œ Description Fix error introduced by #1318 ## πŸ” Related Issues <!-- Link any related issues here --> ## πŸš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### βœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## πŸ§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 2ee7465 commit e9f43f0

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

β€Žcsrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,9 @@ void trtllm_paged_attention_decode(at::Tensor out, std::optional<at::Tensor> con
208208
std::to_string(head_dim_kv) + " and " +
209209
std::to_string(head_dim_qk));
210210
int head_dim_vo = is_4bit(o_data_type) ? out.size(-1) * 2 : out.size(-1);
211-
TORCH_CHECK(head_dim_kv == head_dim_vo, "head_dim_kv and head_dim_vo must be the same, got " +
212-
std::to_string(head_dim_kv) + " and " +
213-
std::to_string(head_dim_vo));
211+
TORCH_CHECK((head_dim_kv == 576 && head_dim_vo == 512) || head_dim_kv == head_dim_vo,
212+
"head_dim_kv and head_dim_vo must be the same for non-MLA attention, got " +
213+
std::to_string(head_dim_kv) + " and " + std::to_string(head_dim_vo));
214214
// NOTE(Zihao): key_value_cache is [num_pages, 1/2, num_kv_heads, page_size, head_dim]
215215
// For KV-Cache sharing (MLA), the second dimension is 1 (key/value cache are shared)
216216
// otherwise it is 2, one for key and one for value

β€Žflashinfer/decode.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2206,6 +2206,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
22062206

22072207
run_func(
22082208
out,
2209+
None, # fp4 output not supported in wrapper api yet.
22092210
query,
22102211
kv_cache.unsqueeze(-3),
22112212
workspace_buffer,
@@ -2214,6 +2215,8 @@ def trtllm_batch_decode_with_kv_cache_mla(
22142215
max_seq_len,
22152216
bmm1_scale,
22162217
bmm2_scale,
2218+
-1, # o_sf_scale
2219+
-1, # o_sf_vec_size
22172220
-1, # window_left
22182221
sm_count,
22192222
)

β€Žtests/test_trtllm_gen_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def test_trtllm_batch_decode_fmha(
349349
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
350350
@pytest.mark.parametrize("page_size", [32, 64])
351351
@pytest.mark.parametrize("q_len_per_request", [1, 2])
352-
@pytest.mark.parametrize("dynamic_scale", [False, True])
352+
@pytest.mark.parametrize("dynamic_scale", [False])
353353
def test_trtllm_batch_decode_mla(
354354
batch_size: int,
355355
scale: float,

0 commit comments

Comments
Β (0)