Commit 175fc73

zihaoye and yyihuang authored
feat: port fast_decode_plan from sgl (#1745)
## 📌 Description

Port `fast_decode_plan` from SGLang (sgl) into FlashInfer, so a `BatchDecodeWithPagedKVCacheWrapper` can be re-planned through a faster path that avoids redundant copies.

## 🔍 Related Issues

#1720

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: Avery Yingyi Huang <[email protected]>
1 parent 1dd4af6 commit 175fc73
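
For context on how the ported function is meant to be used: the new test in this commit binds `fast_decode_plan` over an existing `BatchDecodeWithPagedKVCacheWrapper` with `functools.partial`, after one regular `plan()` call has initialized the wrapper's internal state. Below is a condensed sketch of that pattern; the shapes, dtypes, and page layout are illustrative choices, not part of the diff.

```python
# Condensed sketch of the usage pattern exercised by the new test (illustrative sizes/dtypes).
from functools import partial

import torch
import flashinfer

batch_size, pages_per_seq, page_size = 8, 4, 16
num_qo_heads, num_kv_heads, head_dim = 32, 4, 128
total_pages = batch_size * pages_per_seq

workspace = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

kv_indptr = (
    torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda:0") * pages_per_seq
)
kv_indices = torch.arange(0, total_pages, dtype=torch.int32, device="cuda:0")
kv_last_page_len = torch.full(
    (batch_size,), page_size, dtype=torch.int32, device="cuda:0"
)

plan_args = (
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
)
plan_kwargs = dict(q_data_type=torch.float16, data_type=torch.float16)

# First plan goes through the regular path and sets up the wrapper's internal buffers.
wrapper.plan(*plan_args, **plan_kwargs)

# Later re-plans reuse that state through the ported fast path.
wrapper.plan = partial(flashinfer.fast_decode_plan, wrapper)
wrapper.plan(*plan_args, non_blocking=True, **plan_kwargs)

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
kv_data = torch.randn(
    total_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
)
out = wrapper.run(q, kv_data)
```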

File tree

4 files changed: +592 -0 lines

flashinfer/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,9 @@
 from .decode import (
     CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
 )
+from .decode import (
+    fast_decode_plan as fast_decode_plan,
+)
 from .decode import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
 from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache
 from .fp4_quantization import (

flashinfer/decode.py

Lines changed: 168 additions & 0 deletions
@@ -2366,3 +2366,171 @@ def trtllm_batch_decode_with_kv_cache_mla(
         sinks,
     )
     return out
+
+
+global_override_indptr_cpu = None
+
+
+def fast_decode_plan(
+    self,
+    indptr: torch.Tensor,
+    indices: torch.Tensor,
+    last_page_len: torch.Tensor,
+    num_qo_heads: int,
+    num_kv_heads: int,
+    head_dim: int,
+    page_size: int,
+    pos_encoding_mode: str = "NONE",
+    window_left: int = -1,
+    logits_soft_cap: Optional[float] = None,
+    q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
+    sm_scale: Optional[float] = None,
+    rope_scale: Optional[float] = None,
+    rope_theta: Optional[float] = None,
+    non_blocking: bool = True,
+    fixed_split_size: Optional[int] = None,
+    disable_split_kv: bool = False,
+) -> None:
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+    Modifications:
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
+    - Remove unnecessary host-to-device copy for the metadata buffers.
+    """
+    batch_size = len(last_page_len)
+    if logits_soft_cap is None:
+        logits_soft_cap = 0.0
+
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
+    if self.use_tensor_cores:
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+        # Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
+        if fixed_split_size is None:
+            fixed_split_size = -1
+
+    if self.is_cuda_graph_enabled:
+        if batch_size != self._fixed_batch_size:
+            raise ValueError(
+                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
+                " mismatches the batch size set during initialization {}".format(
+                    batch_size, self._fixed_batch_size
+                )
+            )
+        if len(indices) > len(self._paged_kv_indices_buf):
+            raise ValueError(
+                "The size of indices should be less than or equal to the allocated buffer"
+            )
+    else:
+        self._paged_kv_indptr_buf = indptr
+        self._paged_kv_indices_buf = indices
+        self._paged_kv_last_page_len_buf = last_page_len
+        if self.use_tensor_cores:
+            self._qo_indptr_buf = qo_indptr_host.to(
+                self.device, non_blocking=non_blocking
+            )
+
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+        ),
+        device=self.device,
+    )
+
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, kv_data_type)
+            if isinstance(kv_data_type, str)
+            else kv_data_type
+        ),
+        device=self.device,
+    )
+
+    indptr_host = (
+        global_override_indptr_cpu
+        if global_override_indptr_cpu is not None
+        else indptr.cpu()
+    )
+
+    with torch.cuda.device(self.device):
+        if self.use_tensor_cores:
+            # ALSO convert last_page_len to CPU
+            if page_size == 1:
+                # When page size is 1, last_page_len is always 1.
+                # Directly construct the host tensor rather than executing a device-to-host copy.
+                last_page_len_host = torch.ones(
+                    (batch_size,), dtype=torch.int32, device="cpu"
+                )
+            else:
+                last_page_len_host = last_page_len.cpu()
+
+            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+            try:
+                # Make sure we pass exactly 15 arguments for tensor core version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    qo_indptr_host,
+                    indptr_host,
+                    kv_lens_arr_host,
+                    batch_size,  # total_num_rows
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    head_dim,
+                    head_dim,
+                    False,  # causal
+                    window_left,
+                    fixed_split_size,
+                    disable_split_kv,
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}") from e
+        else:
+            try:
+                # Make sure we pass exactly 15 arguments for standard version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    indptr_host,
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    window_left,
+                    logits_soft_cap,
+                    head_dim,
+                    head_dim,
+                    empty_q_data,
+                    empty_kv_cache,
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}") from e
+
+    self._pos_encoding_mode = pos_encoding_mode
+    self._window_left = window_left
+    self._logits_soft_cap = logits_soft_cap
+    self._sm_scale = sm_scale
+    self._rope_scale = rope_scale
+    self._rope_theta = rope_theta
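
The module-level `global_override_indptr_cpu` introduced above gives callers a hook to supply a host-side copy of `indptr`, letting the fast path skip the implicit `indptr.cpu()` device-to-host copy on every re-plan. A minimal sketch of using that hook follows, continuing the wrapper setup from the earlier sketch; keeping the override in pinned memory and resetting it afterwards are my own assumptions, not requirements of this diff.

```python
# Sketch: supply a host-side indptr so fast_decode_plan skips indptr.cpu() on each re-plan.
# Assumes `wrapper`, `kv_indptr`, `plan_args`, and `plan_kwargs` from the earlier sketch,
# with wrapper.plan already rebound to fast_decode_plan via functools.partial.
import flashinfer.decode as fi_decode

# Pinned memory is an optional choice here, not something fast_decode_plan requires.
indptr_cpu = kv_indptr.to("cpu").pin_memory()

fi_decode.global_override_indptr_cpu = indptr_cpu
try:
    wrapper.plan(*plan_args, non_blocking=True, **plan_kwargs)
finally:
    # Reset so later plans fall back to copying indptr from the device.
    fi_decode.global_override_indptr_cpu = None
```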

tests/test_batch_decode_kernels.py

Lines changed: 150 additions & 0 deletions
@@ -16,6 +16,7 @@
 
 import pytest
 import torch
+from functools import partial
 from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules
 
 import flashinfer
@@ -185,6 +186,155 @@ def test_batch_decode_with_paged_kv_cache(
     torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.parametrize("batch_size", [12, 17, 128])
+@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048, 16384])
+@pytest.mark.parametrize("page_size", [1, 8, 16])
+@pytest.mark.parametrize("num_kv_heads", [4])
+@pytest.mark.parametrize("num_qo_heads", [4, 32])
+@pytest.mark.parametrize("head_dim", [128, 256])
+@pytest.mark.parametrize("kv_layout", ["NHD"])
+@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
+@pytest.mark.parametrize("logits_soft_cap", [0.0])
+@pytest.mark.parametrize("return_lse", [True])
+@pytest.mark.parametrize("q_dtype", [torch.float16])
+@pytest.mark.parametrize("kv_dtype", [torch.float16, torch.float8_e4m3fn])
+@pytest.mark.parametrize("contiguous_kv", [True])
+def test_batch_decode_with_paged_kv_cache_with_fast_plan(
+    batch_size,
+    kv_len,
+    page_size,
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    kv_layout,
+    pos_encoding_mode,
+    logits_soft_cap,
+    return_lse,
+    q_dtype,
+    kv_dtype,
+    contiguous_kv,
+):
+    q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype)
+    num_pages_per_seq = (kv_len + page_size - 1) // page_size
+    total_num_pages = num_pages_per_seq * batch_size
+
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
+        kv_data = kv_data_fp32.to(kv_dtype)
+        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
+        kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
+        # actual data is stored in non-contiguous memory
+        assert (
+            kv_data.stride(-4)
+            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
+        )
+    else:
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
+        kv_data = kv_data_fp32.to(kv_dtype)
+    kv_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
+        * num_pages_per_seq
+    )
+    kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
+    kv_last_page_len = torch.full(
+        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
+    )
+
+    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
+    wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout
+    )
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        logits_soft_cap=logits_soft_cap,
+        pos_encoding_mode=pos_encoding_mode,
+        data_type=kv_dtype,
+        q_data_type=q_dtype,
+    )
+    wrapper.plan = partial(flashinfer.fast_decode_plan, wrapper)
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        logits_soft_cap=logits_soft_cap,
+        pos_encoding_mode=pos_encoding_mode,
+        data_type=kv_dtype,
+        q_data_type=q_dtype,
+        non_blocking=True,
+    )
+    if return_lse:
+        o, _ = wrapper.run(q, kv_data, return_lse=True)
+    else:
+        o = wrapper.run(q, kv_data)
+
+    for i in range(batch_size):
+        perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
+        perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
+        qi = q[i]
+        ki = torch.cat(
+            [
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]
+                .permute(*perm_dims)
+                .reshape(-1, num_kv_heads, head_dim),
+                (
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
+                    if kv_layout == "HND"
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :]
+                )
+                .permute(*perm_dims_last)
+                .reshape(-1, num_kv_heads, head_dim),
+            ],
+            dim=0,
+        ).to(kv_dtype)
+        vi = torch.cat(
+            [
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1]
+                .permute(*perm_dims)
+                .reshape(-1, num_kv_heads, head_dim),
+                (
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]]
+                    if kv_layout == "HND"
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :]
+                )
+                .permute(*perm_dims_last)
+                .reshape(-1, num_kv_heads, head_dim),
+            ],
+            dim=0,
+        ).to(kv_dtype)
+        o_ref_i = flashinfer.decode.single_decode_with_kv_cache(
+            qi,
+            ki,
+            vi,
+            pos_encoding_mode=pos_encoding_mode,
+            logits_soft_cap=logits_soft_cap,
+        )
+        torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3)
+
+    # test user-allocated output
+    o_buffer = torch.empty_like(o)
+    wrapper.run(q, kv_data, out=o_buffer)
+    torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize("batch_size", [12, 17, 128])
 @pytest.mark.parametrize("kv_len", [54, 97, 512, 2048, 16384])
 @pytest.mark.parametrize("page_size", [1, 8, 16])
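To run just the new test locally, a name filter over the test module should be enough; this assumes a CUDA-capable GPU and the repository's `tests/` layout (including `jit_utils`) on the import path.

```python
# Run only the new fast-plan test case from the repository root.
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main(["tests/test_batch_decode_kernels.py", "-k", "fast_plan", "-x"])
    )
```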