
Commit 60dec90

yyihuang and zihaoye authored
fix: should pass global_override_indptr_cpu in fast_decode_plan param list (#1757)
## 📌 Description

fix #1745

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

---

Co-authored-by: Zihao Ye <[email protected]>
1 parent 905f755 · commit 60dec90
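The substance of the fix: `fast_decode_plan` previously read a module-level `global_override_indptr_cpu` defined inside `flashinfer/decode.py`, so a caller in another module could not supply its own host-side indptr without mutating FlashInfer's module state. The diff below moves the override into the parameter list. A minimal toy sketch of the difference (the root-cause framing is inferred from this diff, not stated in it; the two function names are illustrative):

```python
# Contrast of the two styles; names mirror the diff, logic is a toy.
global_override_indptr_cpu = None  # old: module-level state in decode.py


def fast_plan_old_style():
    # Reads this module's global at call time; a caller in another module
    # setting its *own* global of the same name has no effect here.
    return global_override_indptr_cpu


def fast_plan_new_style(global_override_indptr_cpu=None):
    # New: the override is an explicit keyword argument owned by the caller.
    return global_override_indptr_cpu


print(fast_plan_old_style())             # always None
print(fast_plan_new_style("my_indptr"))  # 'my_indptr', passed explicitly
```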

File tree: 3 files changed, +69 −3 lines

flashinfer/decode.py (1 addition, 3 deletions)

@@ -2368,9 +2368,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
     return out
 
 
-global_override_indptr_cpu = None
-
-
 def fast_decode_plan(
     self,
     indptr: torch.Tensor,
@@ -2392,6 +2389,7 @@ def fast_decode_plan(
     non_blocking: bool = True,
     fixed_split_size: Optional[int] = None,
     disable_split_kv: bool = False,
+    global_override_indptr_cpu: Optional[torch.Tensor] = None,
 ) -> None:
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
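With the global gone, code that binds `fast_decode_plan` in place of `wrapper.plan` passes the host indptr explicitly. A minimal end-to-end sketch, assuming a CUDA device and a FlashInfer build with this commit; all sizes and the HND layout are illustrative, and the call pattern mirrors the tests below:

```python
from functools import partial

import torch
import flashinfer

# Illustrative sizes (not from the diff).
batch_size, kv_len, page_size = 8, 256, 16
num_qo_heads, num_kv_heads, head_dim = 4, 4, 128
num_pages_per_seq = (kv_len + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size

kv_data = torch.randn(
    total_num_pages, 2, num_kv_heads, page_size, head_dim,
    dtype=torch.float16, device="cuda",
)
kv_indptr = (
    torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32)
    * num_pages_per_seq
)
kv_indices = torch.arange(0, total_num_pages, device="cuda", dtype=torch.int32)
kv_last_page_len = torch.full(
    (batch_size,), (kv_len - 1) % page_size + 1, device="cuda", dtype=torch.int32
)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace, "HND", use_tensor_cores=True
)

# Bind fast_decode_plan in place of wrapper.plan, as the tests do.
wrapper.plan = partial(flashinfer.fast_decode_plan, wrapper)

# The host-side indptr is now an explicit keyword argument rather than a
# module-level global in flashinfer.decode.
wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    pos_encoding_mode="NONE",
    data_type=torch.float16,
    q_data_type=torch.float16,
    global_override_indptr_cpu=kv_indptr.cpu(),  # new in this commit
)

q = torch.randn(batch_size, num_qo_heads, head_dim,
                dtype=torch.float16, device="cuda")
o = wrapper.run(q, kv_data)
```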

tests/test_batch_decode_kernels.py (14 additions, 0 deletions)

@@ -186,6 +186,10 @@ def test_batch_decode_with_paged_kv_cache(
     torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
 
 
+global_override_indptr_cpu = None
+MAX_BATCH_SIZE = 128
+
+
 @pytest.mark.parametrize("batch_size", [12, 17, 128])
 @pytest.mark.parametrize("kv_len", [54, 97, 512, 2048, 16384])
 @pytest.mark.parametrize("page_size", [1, 8, 16])
@@ -218,6 +222,15 @@ def test_batch_decode_with_paged_kv_cache_with_fast_plan(
     num_pages_per_seq = (kv_len + page_size - 1) // page_size
     total_num_pages = num_pages_per_seq * batch_size
 
+    global global_override_indptr_cpu
+    if global_override_indptr_cpu is None:
+        global_override_indptr_cpu = torch.empty(MAX_BATCH_SIZE + 1, device="cpu")
+    if global_override_indptr_cpu is not None:
+        global_override_indptr_cpu = (
+            torch.arange(0, batch_size + 1, device="cpu", dtype=torch.int32)
+            * num_pages_per_seq
+        )
+
     if kv_layout == "HND":
         kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
     else:
@@ -280,6 +293,7 @@ def test_batch_decode_with_paged_kv_cache_with_fast_plan(
         data_type=kv_dtype,
         q_data_type=q_dtype,
         non_blocking=True,
+        global_override_indptr_cpu=global_override_indptr_cpu,
     )
     if return_lse:
         o, _ = wrapper.run(q, kv_data, return_lse=True)
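The override tensor mirrors the device-side `kv_indptr` on the host: with uniform sequence lengths, entry `i` is `i * num_pages_per_seq`. Note that as written in the test, the placeholder from the first `if` branch is immediately overwritten by the second, so the `arange` value is what reaches `plan`. A quick worked example of that arithmetic (sizes are illustrative):

```python
import torch

batch_size, num_pages_per_seq = 4, 3  # illustrative sizes
indptr_cpu = (
    torch.arange(0, batch_size + 1, device="cpu", dtype=torch.int32)
    * num_pages_per_seq
)
print(indptr_cpu)  # tensor([ 0,  3,  6,  9, 12], dtype=torch.int32)
# Sequence i owns pages indptr_cpu[i] .. indptr_cpu[i+1] - 1.
```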

tests/test_tensor_cores_decode.py (54 additions, 0 deletions)

@@ -328,6 +328,10 @@ def test_batch_decode_tensor_cores_cuda_graph(
     torch.testing.assert_close(lse, lse_tensor_cores, rtol=1e-3, atol=1e-3)
 
 
+global_override_indptr_cpu = None
+MAX_BATCH_SIZE = 128
+
+
 @pytest.mark.parametrize("batch_size", [5, 12])
 @pytest.mark.parametrize("invariant_bs", [4])
 @pytest.mark.parametrize("kv_len", [4096, 8192, 5000])
@@ -358,6 +362,16 @@ def test_batch_decode_tensor_cores_with_fast_plan(
     )
     num_pages_per_seq = (kv_len + page_size - 1) // page_size
     total_num_pages = num_pages_per_seq * batch_size
+
+    global global_override_indptr_cpu
+    if global_override_indptr_cpu is None:
+        global_override_indptr_cpu = torch.empty(MAX_BATCH_SIZE + 1, device="cpu")
+    if global_override_indptr_cpu is not None:
+        global_override_indptr_cpu = (
+            torch.arange(0, batch_size + 1, device="cpu", dtype=torch.int32)
+            * num_pages_per_seq
+        )
+
     kv_data = (
         torch.randn(
             total_num_pages,
@@ -425,13 +439,15 @@ def test_batch_decode_tensor_cores_with_fast_plan(
         q_data_type=torch.float16,
         fixed_split_size=fixed_split_size if not disable_split_kv else None,
         disable_split_kv=disable_split_kv,
+        global_override_indptr_cpu=global_override_indptr_cpu,
     )
     o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run(
         q, kv_data, return_lse=True
    )
 
     kv_indptr_invariant = kv_indptr[: invariant_bs + 1]
     kv_last_page_len_invariant = kv_last_page_len[:invariant_bs]
+    global_override_indptr_cpu = global_override_indptr_cpu[: invariant_bs + 1]
     wrapper_tensor_cores.plan(
         kv_indptr_invariant,
         kv_indices,
@@ -445,6 +461,7 @@ def test_batch_decode_tensor_cores_with_fast_plan(
         q_data_type=torch.float16,
         fixed_split_size=fixed_split_size if not disable_split_kv else None,
         disable_split_kv=disable_split_kv,
+        global_override_indptr_cpu=global_override_indptr_cpu,
     )
     o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run(
         q[:invariant_bs], kv_data, return_lse=True
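Worth noting before the next hunks: when the test replans for only the first `invariant_bs` requests, every per-request array is truncated consistently. Indptr-style arrays carry one more entry than the batch size, so they keep `invariant_bs + 1` entries, while per-request arrays keep `invariant_bs`. A small self-contained illustration of that slicing rule (sizes are illustrative):

```python
import torch

batch_size, invariant_bs, num_pages_per_seq = 5, 4, 2  # illustrative
indptr_cpu = (
    torch.arange(0, batch_size + 1, dtype=torch.int32) * num_pages_per_seq
)
last_page_len = torch.full((batch_size,), 7, dtype=torch.int32)

# indptr has batch_size + 1 entries, so the slice keeps invariant_bs + 1.
indptr_small = indptr_cpu[: invariant_bs + 1]   # 5 entries for 4 requests
last_page_small = last_page_len[:invariant_bs]  # 4 entries

assert indptr_small.numel() == invariant_bs + 1
assert last_page_small.numel() == invariant_bs
```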
@@ -477,6 +494,16 @@ def test_batch_fast_decode_tensor_cores_cuda_graph(
     )
     num_pages_per_seq = (kv_len + page_size - 1) // page_size
     total_num_pages = num_pages_per_seq * batch_size
+
+    global global_override_indptr_cpu
+    if global_override_indptr_cpu is None:
+        global_override_indptr_cpu = torch.empty(MAX_BATCH_SIZE + 1, device="cpu")
+    if global_override_indptr_cpu is not None:
+        global_override_indptr_cpu = (
+            torch.arange(0, batch_size + 1, device="cpu", dtype=torch.int32)
+            * num_pages_per_seq
+        )
+
     kv_data = (
         torch.randn(
             total_num_pages,
@@ -562,6 +589,8 @@ def test_batch_fast_decode_tensor_cores_cuda_graph(
         paged_kv_indices_buffer=kv_indices,
         paged_kv_last_page_len_buffer=kv_last_page_len,
     )
+
+    # cache
     wrapper_tensor_cores.plan(
         kv_indptr,
         kv_indices,
@@ -574,6 +603,24 @@ def test_batch_fast_decode_tensor_cores_cuda_graph(
         data_type=torch.float16,
         q_data_type=torch.float16,
     )
+
+    wrapper_tensor_cores.plan = partial(
+        flashinfer.fast_decode_plan, wrapper_tensor_cores
+    )
+
+    wrapper_tensor_cores.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        pos_encoding_mode=pos_encoding_mode,
+        data_type=torch.float16,
+        q_data_type=torch.float16,
+        global_override_indptr_cpu=global_override_indptr_cpu,
+    )
     # warmup
     s = torch.cuda.Stream()
     s.wait_stream(torch.cuda.current_stream())
@@ -596,3 +643,10 @@ def test_batch_fast_decode_tensor_cores_cuda_graph(
 
     torch.testing.assert_close(o, o_tensor_cores, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(lse, lse_tensor_cores, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    test_batch_decode_tensor_cores_with_fast_plan(
+        5, 4, 4096, 2048, True, 1, 4, 1, 128, "HND", "NONE"
+    )
+    test_batch_fast_decode_tensor_cores_cuda_graph(12, 54, 1, 4, 1, 128, "HND", "NONE")
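The CUDA-graph test swaps the wrapper's bound `plan` method for `fast_decode_plan` via `functools.partial`, which pre-binds the wrapper instance as the function's `self` argument. A minimal sketch of that binding pattern in isolation (`Wrapper` and `fast_plan` are stand-ins for this illustration, not FlashInfer APIs):

```python
from functools import partial


class Wrapper:
    def plan(self, n):
        return f"slow plan for {n}"


def fast_plan(self, n, indptr_cpu=None):
    # Free function that takes the instance explicitly, like fast_decode_plan.
    return f"fast plan for {n}, override={indptr_cpu}"


w = Wrapper()
# partial(fast_plan, w) fixes `self`, so the call signature matches w.plan.
w.plan = partial(fast_plan, w)
print(w.plan(3, indptr_cpu=[0, 1, 2]))  # fast plan for 3, override=[0, 1, 2]
```

Because the replacement happens on the instance (not the class), only this wrapper is affected, and the extra `global_override_indptr_cpu` keyword flows straight through to `fast_decode_plan`.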
