Commit bd2ec06

Add cudnn to the BatchPrefillWithPagedKVCacheWrapper
1 parent 90d00f0 commit bd2ec06

File tree

3 files changed: +210 -149 lines

benchmarks/routines/attention.py

Lines changed: 0 additions & 8 deletions
@@ -798,14 +798,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         .int()
         .to(device)
     )
-    qo_indptr_cudnn = torch.cat(
-        [
-            torch.tensor([0], device=device),
-            torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
-            * head_dim_qk
-            * num_qo_heads,
-        ]
-    ).int()
 
     # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
     kv_indptr = (
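
For context, the deleted lines built cuDNN-specific element offsets by hand; with cuDNN now wired into the wrapper, the benchmark no longer needs them. A minimal sketch of what the removed code computed, with illustrative shapes (tensor names follow the benchmark):

import torch

# Illustrative shapes only; names follow the benchmark.
device = "cuda"
num_qo_heads, head_dim_qk = 8, 128
actual_seq_lens_q_device = torch.tensor([3, 5, 2, 7], device=device)

# Per-sequence element offsets into the packed query tensor:
# cumulative query tokens scaled by num_qo_heads * head_dim_qk, prefixed with 0.
qo_indptr_cudnn = torch.cat(
    [
        torch.tensor([0], device=device),
        torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
        * head_dim_qk
        * num_qo_heads,
    ]
).int()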

flashinfer/prefill.py

Lines changed: 163 additions & 101 deletions
@@ -36,6 +36,7 @@
     setup_metainfo_loader,
     trtllm_gen_fmha_module,
 )
+from .cudnn import cudnn_batch_prefill_with_kv_cache
 from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
 from .quantization import packbits, segment_packbits
 from .utils import (
@@ -1368,7 +1369,7 @@ def __init__(
            mask will be used in attention computation.

        backend : str
-            The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
+            The implementation backend, could be ``auto``/``fa2``,``fa3`` or ``cudnn``. Defaults to ``auto``.
            If set to ``auto``, the wrapper will automatically choose the backend based on the
            device architecture and kernel availability.
@@ -1392,6 +1393,9 @@ def __init__(
         self._jit_module = None

         self._kv_layout = kv_layout
+        if backend == "cudnn":
+            assert kv_layout == "NHD", "CUDNN backend only supports NHD layout"
+
         self._float_workspace_buffer = float_workspace_buffer
         self.device = float_workspace_buffer.device
         self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None
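
With the backend check above, the wrapper can now be constructed for cuDNN directly; the constructor enforces the ``NHD`` KV layout. A minimal construction sketch, assuming a top-level flashinfer install and an illustrative workspace size:

import torch
import flashinfer

# Workspace size and device are illustrative assumptions.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# The new assert means kv_layout must be "NHD" when backend="cudnn".
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout="NHD", backend="cudnn"
)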
@@ -1456,6 +1460,11 @@ def __init__(
         self._mask_indptr_buf = mask_indptr_buf
         self._max_total_num_rows = None
         self._backend = backend
+        self._plan_info = None
+        self._cached_module = None
+        self._seq_lens_kv = None
+        self._seq_lens_q = None
+        self._block_tables = None

     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -1514,7 +1523,10 @@ def plan(
         token_pos_in_items_len: int = 0,
         max_item_len_ptr: Optional[torch.Tensor] = None,
         seq_lens: Optional[torch.Tensor] = None,
+        seq_lens_q: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        max_token_per_sequence: Optional[int] = None,
+        max_sequence_kv: Optional[int] = None,
     ) -> None:
         r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
@@ -1605,8 +1617,15 @@ def plan(
            a uint16 vector contains the max token length of all items for each prompt
        seq_lens: Optional[torch.Tensor]
            A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
+        seq_lens_q: Optional[torch.Tensor]
+            A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``.
+            If not provided, will be set to the same value as ``seq_lens``.
        block_tables: Optional[torch.Tensor]
            A uint32 2D tensor indicating the block table of each prompt. shape: ``[batch_size, max_num_blocks_per_seq]``.
+        max_token_per_sequence: Optional[int],
+            Required for cudnn backend. This is the scalar max token length of each sequence.
+        max_sequence_kv: Optional[int],
+            Required for cudnn backend. This is the scalar max sequence length of each sequence in kv cache.

        Note
        ----
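
A hedged sketch of a plan() call that exercises the new keyword arguments, continuing from the construction sketch above; all sizes and page-table contents are illustrative, not taken from the commit:

import torch

batch_size, page_size = 4, 16
num_qo_heads, num_kv_heads, head_dim = 8, 8, 128
device = "cuda"

seq_lens_q = torch.tensor([3, 5, 2, 7], dtype=torch.int32, device=device)
seq_lens_kv = torch.tensor([30, 50, 20, 70], dtype=torch.int32, device=device)

# Standard paged-KV metadata: prefix sums over query tokens and KV pages.
qo_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int64, device=device), torch.cumsum(seq_lens_q, 0)]
).int()
pages_per_seq = (seq_lens_kv + page_size - 1) // page_size
kv_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int64, device=device), torch.cumsum(pages_per_seq, 0)]
).int()
num_pages = int(kv_indptr[-1])
kv_indices = torch.arange(num_pages, dtype=torch.int32, device=device)
kv_last_page_len = ((seq_lens_kv - 1) % page_size + 1).int()

# Block table: row i lists the page ids owned by sequence i.
block_tables = torch.zeros(
    (batch_size, int(pages_per_seq.max())), dtype=torch.int32, device=device
)
for i in range(batch_size):
    n = int(pages_per_seq[i])
    block_tables[i, :n] = kv_indices[int(kv_indptr[i]) : int(kv_indptr[i]) + n]

wrapper.plan(
    qo_indptr,
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    causal=True,
    q_data_type=torch.float16,
    kv_data_type=torch.float16,
    seq_lens=seq_lens_kv,
    seq_lens_q=seq_lens_q,
    block_tables=block_tables,
    max_token_per_sequence=int(seq_lens_q.max()),
    max_sequence_kv=int(seq_lens_kv.max()),
)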
@@ -1655,22 +1674,28 @@ def plan(
         self._max_item_len_ptr = max_item_len_ptr

         # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
-        paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
-        if seq_lens is None:
-            kv_lens_arr_host = get_seq_lens(
-                paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
-            )
-        else:
-            kv_lens_arr_host = seq_lens.cpu()
-        self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
-            kv_lens_arr_host, non_blocking=non_blocking
-        )
-        self._max_q_len = max(qo_indptr_host).item()
-        self._max_kv_len = max(kv_lens_arr_host).item()
-
-        total_num_rows = qo_indptr_host[-1]
+        if max_token_per_sequence is not None:
+            self._max_q_len = max_token_per_sequence
+        else:
+            qo_indptr_host = qo_indptr.to("cpu")
+            self._max_q_len = max(qo_indptr_host).item()
+            total_num_rows = qo_indptr_host[-1]
+
+        if max_sequence_kv is not None:
+            self._max_kv_len = max_sequence_kv
+        else:
+            paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+            paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
+            if seq_lens is None:
+                kv_lens_arr_host = get_seq_lens(
+                    paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
+                )
+            else:
+                kv_lens_arr_host = seq_lens.cpu().flatten()
+            self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
+                kv_lens_arr_host, non_blocking=non_blocking
+            )
+            self._max_kv_len = max(kv_lens_arr_host).item()

         if self.is_cuda_graph_enabled:
             if self._max_total_num_rows is None:
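
When max_sequence_kv is omitted, the per-sequence KV lengths are still derived on the host via get_seq_lens. A small sketch of the underlying relation, assuming the usual paged-KV convention of full pages plus a partially filled last page:

import torch

def seq_lens_from_paged_layout(
    kv_indptr: torch.Tensor,         # [batch_size + 1], prefix sum of pages per sequence
    kv_last_page_len: torch.Tensor,  # [batch_size], tokens used in the last page
    page_size: int,
) -> torch.Tensor:
    # All pages except the last are full; the last contributes last_page_len tokens.
    pages_per_seq = kv_indptr[1:] - kv_indptr[:-1]
    return (pages_per_seq - 1).clamp(min=0) * page_size + kv_last_page_len

# Example: page_size=16, two sequences holding 30 and 50 tokens.
kv_indptr = torch.tensor([0, 2, 6])        # 2 pages, then 4 pages
kv_last_page_len = torch.tensor([14, 2])   # 16+14=30, 3*16+2=50
print(seq_lens_from_paged_layout(kv_indptr, kv_last_page_len, 16))  # tensor([30, 50])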
@@ -1759,23 +1784,23 @@ def plan(
            q_data_type,
            kv_data_type,
        )
+        if self._backend != "cudnn":
+            get_module_args = (
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                paged_kv_indptr.dtype,
+                head_dim_qk,
+                head_dim_vo,
+                PosEncodingMode[pos_encoding_mode].value,
+                window_left >= 0,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+                use_fp16_qk_reduction,
+            )

-        get_module_args = (
-            q_data_type,
-            kv_data_type,
-            q_data_type,
-            paged_kv_indptr.dtype,
-            head_dim_qk,
-            head_dim_vo,
-            PosEncodingMode[pos_encoding_mode].value,
-            window_left >= 0,  # use_sliding_window
-            logits_soft_cap > 0,  # use_logits_soft_cap
-            use_fp16_qk_reduction,
-        )
-
-        self._cached_module = get_batch_prefill_module(
-            self._backend, *get_module_args
-        )
+            self._cached_module = get_batch_prefill_module(
+                self._backend, *get_module_args
+            )

        if self._backend == "fa3" or self._backend == "trtllm-gen":
            if page_size != 1:
@@ -1793,7 +1818,7 @@ def plan(
                ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
                paged_kv_indptr_host = vector_sparse_indptr_host

-        self._block_tables: Optional[torch.Tensor] = block_tables
+        self._block_tables = block_tables
        if self._backend == "trtllm-gen":
            assert self._kv_layout == "HND"
            assert logits_soft_cap == 0.0
@@ -1811,28 +1836,32 @@ def plan(
            block_id = paged_kv_indptr_host[0]
            for i in range(batch_size):
                num_blocks_needed = blocks_per_seq[i]
+                assert self._block_tables is not None, (
+                    "block_tables is not initialized"
+                )
                self._block_tables[i, :num_blocks_needed] = paged_kv_indices[
                    block_id : block_id + num_blocks_needed
                ]
                block_id += num_blocks_needed

-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            paged_kv_indptr_host,
-            kv_lens_arr_host,
-            self._max_total_num_rows or total_num_rows,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim_qk,
-            head_dim_vo,
-            causal,
-        )
+        if self._cached_module is not None:
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                paged_kv_indptr_host,
+                kv_lens_arr_host,
+                self._max_total_num_rows or total_num_rows,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim_qk,
+                head_dim_vo,
+                causal,
+            )

        self._causal = causal
        self._pos_encoding_mode = pos_encoding_mode
@@ -1842,6 +1871,8 @@ def plan(
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta
+        self._seq_lens_kv = seq_lens
+        self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens

    begin_forward = plan
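
The stored seq_lens / seq_lens_q tensors are consumed by run() in the next hunk, which reshapes 1-D length tensors to the [batch, 1, 1, 1] form that cudnn_batch_prefill_with_kv_cache appears to expect for per-sequence lengths. A tiny illustration of that reshape (shapes are assumptions):

import torch

batch_size = 4
seq_lens_q = torch.tensor([3, 5, 2, 7], dtype=torch.int32, device="cuda")

# What run() does for the cudnn backend when it sees a 1-D tensor:
if seq_lens_q.dim() == 1:
    seq_lens_q = seq_lens_q.reshape(batch_size, 1, 1, 1)
print(seq_lens_q.shape)  # torch.Size([4, 1, 1, 1])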

@@ -2042,62 +2073,90 @@ def run(
            sparse_indices = self._paged_kv_indices_buf
            sparse_indptr = self._paged_kv_indptr_buf

-        run_args = [
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k_cache,
-            v_cache,
-            self._qo_indptr_buf,
-            sparse_indptr,
-            sparse_indices,
-            self._paged_kv_last_page_len_buf,
-            out,
-            lse,
-            mask_mode,
-            TensorLayout[self._kv_layout].value,
-            window_left,
-            enable_pdl,
-        ]
-        if self._jit_module is not None:
-            run_args.extend(list(args))
-        else:
-            run_args += [
-                self._custom_mask_buf,
-                self._mask_indptr_buf,
-                _get_cache_alibi_slopes_buf(q.shape[1], q.device),
-                self._prefix_len_ptr,
-                self._token_pos_in_items_ptr,
-                self._max_item_len_ptr,
-                logits_soft_cap,
-                sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
-                rope_scale,
-                rope_theta,
-                self._token_pos_in_items_len,
-                self._num_qo_heads,
-                self._num_kv_heads,
-                self._block_tables,
-                self._kv_lens_buffer,
-                page_size,
-                self._max_q_len,
-                self._max_kv_len,
-                self._batch_size,
-                self._qo_indptr_buf,
-                self._vector_sparse_indptr_buffer,
-                sinks,
-            ]
-
-        self._cached_module.paged_run(*run_args)
-        if v_scale is not None:
-            # TODO(Zihao): fused into kernel
-            if is_float8(out):
-                out = (out.to(torch.float32) * v_scale).to(out.dtype)
-            else:
-                out *= v_scale
+        if self._backend == "cudnn":
+            if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
+                self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1)
+
+            if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1:
+                self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1)
+
+            cudnn_batch_prefill_with_kv_cache(
+                q,
+                k_cache,  # Need to be changed
+                v_cache,  # Need to be changed
+                self._sm_scale,
+                self._float_workspace_buffer,
+                actual_seq_lens_q=self._seq_lens_q,
+                actual_seq_lens_kv=self._seq_lens_kv,
+                max_token_per_sequence=self._max_q_len,
+                max_sequence_kv=self._max_kv_len,
+                block_tables=self._block_tables,
+                causal=self._causal,
+                return_lse=return_lse,
+                batch_offsets_q=self._qo_indptr_buf,
+                batch_offsets_o=self._qo_indptr_buf,
+                out=out,
+                lse=lse,
+            )
+        else:
+            assert self._plan_info is not None, "plan info is not initialized"
+            run_args = [
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k_cache,
+                v_cache,
+                self._qo_indptr_buf,
+                sparse_indptr,
+                sparse_indices,
+                self._paged_kv_last_page_len_buf,
+                out,
+                lse,
+                mask_mode,
+                TensorLayout[self._kv_layout].value,
+                window_left,
+                enable_pdl,
+            ]
+            if self._jit_module is not None:
+                run_args.extend(list(args))
+            else:
+                run_args += [
+                    self._custom_mask_buf,
+                    self._mask_indptr_buf,
+                    _get_cache_alibi_slopes_buf(q.shape[1], q.device),
+                    self._prefix_len_ptr,
+                    self._token_pos_in_items_ptr,
+                    self._max_item_len_ptr,
+                    logits_soft_cap,
+                    sm_scale,
+                    None,  # scale_q, not supported yet
+                    None,  # scale_k
+                    None,  # scale_v
+                    rope_scale,
+                    rope_theta,
+                    self._token_pos_in_items_len,
+                    self._num_qo_heads,
+                    self._num_kv_heads,
+                    self._block_tables,
+                    self._kv_lens_buffer,
+                    page_size,
+                    self._max_q_len,
+                    self._max_kv_len,
+                    self._batch_size,
+                    self._qo_indptr_buf,
+                    self._vector_sparse_indptr_buffer,
+                    sinks,
+                ]
+
+            assert self._cached_module is not None, "cached module is not initialized"
+            self._cached_module.paged_run(*run_args)
+            if v_scale is not None:
+                # TODO(Zihao): fused into kernel
+                if is_float8(out):
+                    out = (out.to(torch.float32) * v_scale).to(out.dtype)
+                else:
+                    out *= v_scale
        return (out, lse) if return_lse else out

    run_return_lse = functools.partialmethod(run, return_lse=True)
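
An end-to-end sketch of the cudnn run path, continuing the plan() sketch above; the paged KV cache shapes assume the NHD layout required by this backend and the tensor contents are illustrative:

import torch

# Paged KV cache tensors in NHD layout: [num_pages, page_size, num_kv_heads, head_dim].
q = torch.randn(
    int(seq_lens_q.sum()), num_qo_heads, head_dim, dtype=torch.float16, device=device
)
k_cache = torch.randn(
    num_pages, page_size, num_kv_heads, head_dim, dtype=torch.float16, device=device
)
v_cache = torch.randn_like(k_cache)

out = wrapper.run(q, (k_cache, v_cache))                  # [total_q_tokens, num_qo_heads, head_dim]
out, lse = wrapper.run_return_lse(q, (k_cache, v_cache))  # run_return_lse is defined in this file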
@@ -2351,6 +2410,7 @@ def __init__(
        self._mask_indptr_buf = mask_indptr_buf
        self._max_total_num_rows = None
        self._backend = backend
+        self._cached_module = None

    @property
    def is_cuda_graph_enabled(self) -> bool:
@@ -2621,6 +2681,7 @@ def plan(
            )
            self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item()
        else:
+            assert self._cached_module is not None, "cached module is not initialized"
            self._plan_info = self._cached_module.plan(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
@@ -2845,6 +2906,7 @@ def run(
                self._token_pos_in_items_len,
            ]

+        assert self._cached_module is not None, "cached module is not initialized"
        self._cached_module.ragged_run(*run_args)
        return (out, lse) if return_lse else out
0 commit comments
