Commit 2b5977e

Add cudnn to the BatchPrefillWithPagedKVCacheWrapper
1 parent 0c850d3 commit 2b5977e

File tree

3 files changed: +192 -143 lines changed

benchmarks/routines/attention.py

Lines changed: 0 additions & 8 deletions

@@ -799,14 +799,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         .int()
         .to(device)
     )
-    qo_indptr_cudnn = torch.cat(
-        [
-            torch.tensor([0], device=device),
-            torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
-            * head_dim_qk
-            * num_qo_heads,
-        ]
-    ).int()
 
     # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
     kv_indptr = (
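
Note: the deleted block built cudnn-specific offsets by scaling the cumulative query lengths by head_dim_qk * num_qo_heads, which the standalone cudnn prefill entry point needed. With cudnn now dispatched inside the wrapper, the benchmark only needs the ordinary token-count qo_indptr it already constructs. A minimal sketch of that simpler indptr, using hypothetical lengths:

    # Hypothetical values; actual_seq_lens_q_device mirrors the benchmark's per-request query lengths.
    import torch

    device = "cuda"
    actual_seq_lens_q_device = torch.tensor([3, 5, 2], device=device)
    qo_indptr = torch.cat(
        [
            torch.zeros(1, dtype=torch.int64, device=device),
            torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0),
        ]
    ).int()  # -> [0, 3, 8, 10]; no head_dim_qk / num_qo_heads scaling required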

flashinfer/prefill.py

Lines changed: 145 additions & 95 deletions
@@ -1366,7 +1366,7 @@ def __init__(
             mask will be used in attention computation.
 
         backend : str
-            The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
+            The implementation backend, could be ``auto``, ``fa2``, ``fa3`` or ``cudnn``. Defaults to ``auto``.
             If set to ``auto``, the wrapper will automatically choose the backend based on the
             device architecture and kernel availability.
@@ -1388,6 +1388,9 @@ def __init__(
         self._jit_module = None
 
         self._kv_layout = kv_layout
+        if backend == "cudnn":
+            assert kv_layout == "NHD", "CUDNN backend only supports NHD layout"
+
         self._float_workspace_buffer = float_workspace_buffer
         self.device = float_workspace_buffer.device
         self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None
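
A minimal construction sketch for the new option (the workspace size here is an assumption, not taken from the commit); the cudnn backend now asserts the NHD KV layout, which is also the wrapper's default:

    # Hypothetical sketch: selecting the cudnn backend at wrapper construction time.
    import torch
    import flashinfer

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace,
        kv_layout="NHD",   # any other layout trips the new assertion
        backend="cudnn",
    )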
@@ -1452,6 +1455,10 @@ def __init__(
         self._mask_indptr_buf = mask_indptr_buf
         self._max_total_num_rows = None
         self._backend = backend
+        self._cached_module = None
+        self._seq_lens_kv = None
+        self._seq_lens_q = None
+        self._block_tables = None
 
     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -1510,7 +1517,10 @@ def plan(
         token_pos_in_items_len: int = 0,
         max_item_len_ptr: Optional[torch.Tensor] = None,
         seq_lens: Optional[torch.Tensor] = None,
+        seq_lens_q: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        max_token_per_sequence: Optional[int] = None,
+        max_sequence_kv: Optional[int] = None,
     ) -> None:
         r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
@@ -1601,6 +1611,9 @@ def plan(
             a uint16 vector contains the max token length of all items for each prompt
         seq_lens: Optional[torch.Tensor]
             A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
+        seq_lens_q: Optional[torch.Tensor]
+            A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``.
+            If not provided, will be set to the same value as ``seq_lens``.
         block_tables: Optional[torch.Tensor]
             A uint32 2D tensor indicating the block table of each prompt. shape: ``[batch_size, max_num_blocks_per_seq]``.
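
A hedged sketch of a plan() call exercising the new arguments (shapes, head counts, and page size are illustrative and continue the constructor sketch above); supplying max_token_per_sequence and max_sequence_kv lets plan() skip the host-side reductions shown in the next hunk:

    # Hypothetical sketch: planning for the cudnn backend with explicit sequence-length hints.
    import torch

    batch_size, page_size, max_q, max_kv = 4, 16, 128, 1024
    pages_per_seq = max_kv // page_size
    device = "cuda"

    qo_indptr = torch.arange(0, (batch_size + 1) * max_q, max_q, dtype=torch.int32, device=device)
    kv_indptr = torch.arange(0, (batch_size + 1) * pages_per_seq, pages_per_seq, dtype=torch.int32, device=device)
    kv_indices = torch.arange(batch_size * pages_per_seq, dtype=torch.int32, device=device)
    last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device=device)
    seq_lens_kv = torch.full((batch_size,), max_kv, dtype=torch.int32, device=device)
    seq_lens_q = torch.full((batch_size,), max_q, dtype=torch.int32, device=device)
    block_tables = kv_indices.view(batch_size, pages_per_seq)

    wrapper.plan(
        qo_indptr,
        kv_indptr,
        kv_indices,
        last_page_len,
        num_qo_heads=32,
        num_kv_heads=8,
        head_dim_qk=128,
        page_size=page_size,
        causal=True,
        seq_lens=seq_lens_kv,
        seq_lens_q=seq_lens_q,
        block_tables=block_tables,
        max_token_per_sequence=max_q,
        max_sequence_kv=max_kv,
    )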
@@ -1651,22 +1664,28 @@ def plan(
         self._max_item_len_ptr = max_item_len_ptr
 
         # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
-        paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
-        if seq_lens is None:
-            kv_lens_arr_host = get_seq_lens(
-                paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
-            )
+        if max_token_per_sequence is not None:
+            self._max_q_len = max_token_per_sequence
         else:
-            kv_lens_arr_host = seq_lens.cpu()
-        self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
-            kv_lens_arr_host, non_blocking=non_blocking
-        )
-        self._max_q_len = max(qo_indptr_host).item()
-        self._max_kv_len = max(kv_lens_arr_host).item()
+            qo_indptr_host = qo_indptr.to("cpu")
+            self._max_q_len = max(qo_indptr_host).item()
+            total_num_rows = qo_indptr_host[-1]
 
-        total_num_rows = qo_indptr_host[-1]
+        if max_sequence_kv is not None:
+            self._max_kv_len = max_sequence_kv
+        else:
+            paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+            paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
+            if seq_lens is None:
+                kv_lens_arr_host = get_seq_lens(
+                    paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
+                )
+            else:
+                kv_lens_arr_host = seq_lens.cpu().flatten()
+            self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
+                kv_lens_arr_host, non_blocking=non_blocking
+            )
+            self._max_kv_len = max(kv_lens_arr_host).item()
 
         if self.is_cuda_graph_enabled:
             if self._max_total_num_rows is None:
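
For orientation, a tiny worked example (hypothetical numbers) of the fallback branch that runs when max_token_per_sequence / max_sequence_kv are not given: the query-side maximum is read off the host copy of qo_indptr, and per-request KV lengths are recovered from the paged indptr plus the last-page lengths (the formula below is a host-side stand-in for what get_seq_lens is assumed to compute):

    import torch

    page_size = 16
    qo_indptr_host = torch.tensor([0, 32, 64, 96])          # 3 requests, 32 query tokens each
    paged_kv_indptr_host = torch.tensor([0, 2, 4, 6])       # 2 pages per request
    paged_kv_last_page_len_host = torch.tensor([5, 16, 9])  # tokens used in each last page

    max_q_len = max(qo_indptr_host).item()       # 96 (indptr is cumulative)
    total_num_rows = qo_indptr_host[-1]          # 96

    num_pages = paged_kv_indptr_host[1:] - paged_kv_indptr_host[:-1]
    kv_lens = (num_pages - 1) * page_size + paged_kv_last_page_len_host
    # kv_lens -> tensor([21, 32, 25]); max_kv_len = 32
    max_kv_len = kv_lens.max().item()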
@@ -1755,23 +1774,23 @@ def plan(
             q_data_type,
             kv_data_type,
         )
+        if self._backend != "cudnn":
+            get_module_args = (
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                paged_kv_indptr.dtype,
+                head_dim_qk,
+                head_dim_vo,
+                PosEncodingMode[pos_encoding_mode].value,
+                window_left >= 0,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+                use_fp16_qk_reduction,
+            )
 
-        get_module_args = (
-            q_data_type,
-            kv_data_type,
-            q_data_type,
-            paged_kv_indptr.dtype,
-            head_dim_qk,
-            head_dim_vo,
-            PosEncodingMode[pos_encoding_mode].value,
-            window_left >= 0,  # use_sliding_window
-            logits_soft_cap > 0,  # use_logits_soft_cap
-            use_fp16_qk_reduction,
-        )
-
-        self._cached_module = get_batch_prefill_module(
-            self._backend, *get_module_args
-        )
+            self._cached_module = get_batch_prefill_module(
+                self._backend, *get_module_args
+            )
 
         if self._backend == "fa3" or self._backend == "trtllm-gen":
             if page_size != 1:
@@ -1789,7 +1808,6 @@ def plan(
                 ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
                 paged_kv_indptr_host = vector_sparse_indptr_host
 
-        self._block_tables: Optional[torch.Tensor] = block_tables
         if self._backend == "trtllm-gen":
             assert self._kv_layout == "HND"
             assert logits_soft_cap == 0.0
@@ -1812,32 +1830,36 @@ def plan(
                 ]
                 block_id += num_blocks_needed
 
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            paged_kv_indptr_host,
-            kv_lens_arr_host,
-            self._max_total_num_rows or total_num_rows,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim_qk,
-            head_dim_vo,
-            causal,
-        )
+        if self._cached_module is not None:
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                paged_kv_indptr_host,
+                kv_lens_arr_host,
+                self._max_total_num_rows or total_num_rows,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim_qk,
+                head_dim_vo,
+                causal,
+            )
 
         self._causal = causal
+        self._block_tables = block_tables
         self._pos_encoding_mode = pos_encoding_mode
         self._use_fp16_qk_reduction = use_fp16_qk_reduction
         self._window_left = window_left
         self._logits_soft_cap = logits_soft_cap
         self._sm_scale = sm_scale
         self._rope_scale = rope_scale
         self._rope_theta = rope_theta
+        self._seq_lens_kv = seq_lens
+        self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens
 
     begin_forward = plan

@@ -2038,56 +2060,83 @@ def run(
             sparse_indices = self._paged_kv_indices_buf
             sparse_indptr = self._paged_kv_indptr_buf
 
-        run_args = [
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k_cache,
-            v_cache,
-            self._qo_indptr_buf,
-            sparse_indptr,
-            sparse_indices,
-            self._paged_kv_last_page_len_buf,
-            out,
-            lse,
-            mask_mode,
-            TensorLayout[self._kv_layout].value,
-            window_left,
-            enable_pdl,
-        ]
-        if self._jit_module is not None:
-            run_args.extend(list(args))
+        if self._backend == "cudnn":
+
+            if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
+                self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1)
+
+            if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1:
+                self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1)
+
+            cudnn_batch_prefill_with_kv_cache(
+                q,
+                k_cache,  # Need to be changed
+                v_cache,  # Need to be changed
+                self._sm_scale,
+                self._float_workspace_buffer,
+                actual_seq_lens_q=self._seq_lens_q,
+                actual_seq_lens_kv=self._seq_lens_kv,
+                max_token_per_sequence=self._max_q_len,
+                max_sequence_kv=self._max_kv_len,
+                block_tables=self._block_tables,
+                causal=self._causal,
+                return_lse=return_lse,
+                batch_offsets_q=self._qo_indptr_buf,
+                batch_offsets_o=self._qo_indptr_buf,
+                out=out,
+                lse=lse,
+            )
         else:
-            run_args += [
-                self._custom_mask_buf,
-                self._mask_indptr_buf,
-                _get_cache_alibi_slopes_buf(q.shape[1], q.device),
-                self._prefix_len_ptr,
-                self._token_pos_in_items_ptr,
-                self._max_item_len_ptr,
-                logits_soft_cap,
-                sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
-                rope_scale,
-                rope_theta,
-                self._token_pos_in_items_len,
-                self._num_qo_heads,
-                self._num_kv_heads,
-                self._block_tables,
-                self._kv_lens_buffer,
-                page_size,
-                self._max_q_len,
-                self._max_kv_len,
-                self._batch_size,
+            run_args = [
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k_cache,
+                v_cache,
                 self._qo_indptr_buf,
-                self._vector_sparse_indptr_buffer,
-                sinks,
+                sparse_indptr,
+                sparse_indices,
+                self._paged_kv_last_page_len_buf,
+                out,
+                lse,
+                mask_mode,
+                TensorLayout[self._kv_layout].value,
+                window_left,
+                enable_pdl,
             ]
+            if self._jit_module is not None:
+                run_args.extend(list(args))
+            else:
+                run_args += [
+                    self._custom_mask_buf,
+                    self._mask_indptr_buf,
+                    _get_cache_alibi_slopes_buf(q.shape[1], q.device),
+                    self._prefix_len_ptr,
+                    self._token_pos_in_items_ptr,
+                    self._max_item_len_ptr,
+                    logits_soft_cap,
+                    sm_scale,
+                    None,  # scale_q, not supported yet
+                    None,  # scale_k
+                    None,  # scale_v
+                    rope_scale,
+                    rope_theta,
+                    self._token_pos_in_items_len,
+                    self._num_qo_heads,
+                    self._num_kv_heads,
+                    self._block_tables,
+                    self._kv_lens_buffer,
+                    page_size,
+                    self._max_q_len,
+                    self._max_kv_len,
+                    self._batch_size,
+                    self._qo_indptr_buf,
+                    self._vector_sparse_indptr_buffer,
+                    sinks,
+                ]
 
-        self._cached_module.paged_run(*run_args)
+            self._cached_module.paged_run(*run_args)
 
         return (out, lse) if return_lse else out
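
An end-to-end sketch of the new run path, continuing the construction and plan() sketches above (cache shapes follow the NHD paged layout; note the commit itself still flags the k_cache / v_cache handling as "Need to be changed"):

    # Hypothetical sketch: running the planned cudnn prefill through the wrapper.
    import torch

    num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
    total_pages = batch_size * pages_per_seq

    q = torch.randn(batch_size * max_q, num_qo_heads, head_dim, dtype=torch.half, device=device)
    # Paged KV cache, NHD layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
    kv_cache = torch.randn(total_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.half, device=device)

    out = wrapper.run(q, kv_cache)                         # dispatches to cudnn_batch_prefill_with_kv_cache
    out, lse = wrapper.run(q, kv_cache, return_lse=True)   # LSE returned when requested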

@@ -2340,6 +2389,7 @@ def __init__(
         self._mask_indptr_buf = mask_indptr_buf
         self._max_total_num_rows = None
         self._backend = backend
+        self._cached_module = None
 
     @property
     def is_cuda_graph_enabled(self) -> bool:
