
Commit 220041c

Add cudnn to the BatchPrefillWithPagedKVCacheWrapper
1 parent 20bbc34 commit 220041c


3 files changed: +200 −143 lines


benchmarks/routines/attention.py

Lines changed: 0 additions & 8 deletions
@@ -798,14 +798,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         .int()
         .to(device)
     )
-    qo_indptr_cudnn = torch.cat(
-        [
-            torch.tensor([0], device=device),
-            torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
-            * head_dim_qk
-            * num_qo_heads,
-        ]
-    ).int()

     # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
     kv_indptr = (
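
For reference, the deleted lines computed cudnn-style ragged offsets by hand: a cumulative sum of per-sequence query lengths scaled by num_qo_heads * head_dim_qk. A standalone sketch of that pattern (tensor values and sizes are hypothetical), which the benchmark no longer needs now that the wrapper drives cudnn internally:

import torch

device = "cuda"
num_qo_heads, head_dim_qk = 8, 128
# Per-sequence query lengths for a batch of 3 (hypothetical values).
actual_seq_lens_q_device = torch.tensor([5, 3, 7], device=device)

# Element offset of each sequence's first query value in a flat [tokens * heads * dim] buffer.
qo_indptr_cudnn = torch.cat(
    [
        torch.tensor([0], device=device),
        torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
        * head_dim_qk
        * num_qo_heads,
    ]
).int()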

flashinfer/prefill.py

Lines changed: 153 additions & 95 deletions
@@ -33,6 +33,7 @@
     setup_metainfo_loader,
     trtllm_gen_fmha_module,
 )
+from .cudnn import cudnn_batch_prefill_with_kv_cache
 from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
 from .quantization import packbits, segment_packbits
 from .utils import (
@@ -1365,7 +1366,7 @@ def __init__(
             mask will be used in attention computation.

         backend : str
-            The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
+            The implementation backend, could be ``auto``/``fa2``,``fa3`` or ``cudnn``. Defaults to ``auto``.
             If set to ``auto``, the wrapper will automatically choose the backend based on the
             device architecture and kernel availability.
@@ -1389,6 +1390,9 @@ def __init__(
         self._jit_module = None

         self._kv_layout = kv_layout
+        if backend == "cudnn":
+            assert kv_layout == "NHD", "CUDNN backend only supports NHD layout"
+
         self._float_workspace_buffer = float_workspace_buffer
         self.device = float_workspace_buffer.device
         self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None
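
A minimal construction sketch for the new backend (assuming the public constructor forwards backend as shown in this diff; the workspace size is illustrative). Any layout other than NHD would trip the assert added above:

import torch
import flashinfer

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout="NHD", backend="cudnn"
)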
@@ -1453,6 +1457,11 @@ def __init__(
         self._mask_indptr_buf = mask_indptr_buf
         self._max_total_num_rows = None
         self._backend = backend
+        self._plan_info = None
+        self._cached_module = None
+        self._seq_lens_kv = None
+        self._seq_lens_q = None
+        self._block_tables = None

     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -1511,7 +1520,10 @@ def plan(
         token_pos_in_items_len: int = 0,
         max_item_len_ptr: Optional[torch.Tensor] = None,
         seq_lens: Optional[torch.Tensor] = None,
+        seq_lens_q: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        max_token_per_sequence: Optional[int] = None,
+        max_sequence_kv: Optional[int] = None,
     ) -> None:
         r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.

@@ -1602,6 +1614,9 @@ def plan(
             a uint16 vector contains the max token length of all items for each prompt
         seq_lens: Optional[torch.Tensor]
             A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
+        seq_lens_q: Optional[torch.Tensor]
+            A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``.
+            If not provided, will be set to the same value as ``seq_lens``.
         block_tables: Optional[torch.Tensor]
             A uint32 2D tensor indicating the block table of each prompt. shape: ``[batch_size, max_num_blocks_per_seq]``.

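Continuing the construction sketch above, a plan() call that exercises the new parameters. The positional arguments follow the wrapper's existing plan() signature (qo_indptr, paged KV metadata, head counts, head dim, page size) and all tensor contents are illustrative:

page_size = 16
qo_indptr = torch.tensor([0, 4, 6], dtype=torch.int32, device="cuda")        # q lens [4, 2]
paged_kv_indptr = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")  # one page each
paged_kv_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.tensor([16, 8], dtype=torch.int32, device="cuda")
seq_lens_kv = torch.tensor([16, 8], dtype=torch.int32, device="cuda")
seq_lens_q = torch.tensor([4, 2], dtype=torch.int32, device="cuda")
block_tables = torch.tensor([[0], [1]], dtype=torch.int32, device="cuda")

wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    8,    # num_qo_heads
    8,    # num_kv_heads
    128,  # head_dim_qk
    page_size,
    causal=True,
    seq_lens=seq_lens_kv,    # kv lengths
    seq_lens_q=seq_lens_q,   # q lengths; defaults to seq_lens when omitted
    block_tables=block_tables,
)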
@@ -1652,22 +1667,28 @@ def plan(
         self._max_item_len_ptr = max_item_len_ptr

         # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
-        paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
-        if seq_lens is None:
-            kv_lens_arr_host = get_seq_lens(
-                paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
-            )
+        if max_token_per_sequence is not None:
+            self._max_q_len = max_token_per_sequence
         else:
-            kv_lens_arr_host = seq_lens.cpu()
-        self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
-            kv_lens_arr_host, non_blocking=non_blocking
-        )
-        self._max_q_len = max(qo_indptr_host).item()
-        self._max_kv_len = max(kv_lens_arr_host).item()
+            qo_indptr_host = qo_indptr.to("cpu")
+            self._max_q_len = max(qo_indptr_host).item()
+            total_num_rows = qo_indptr_host[-1]

-        total_num_rows = qo_indptr_host[-1]
+        if max_sequence_kv is not None:
+            self._max_kv_len = max_sequence_kv
+        else:
+            paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+            paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
+            if seq_lens is None:
+                kv_lens_arr_host = get_seq_lens(
+                    paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
+                )
+            else:
+                kv_lens_arr_host = seq_lens.cpu().flatten()
+            self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
+                kv_lens_arr_host, non_blocking=non_blocking
+            )
+            self._max_kv_len = max(kv_lens_arr_host).item()

         if self.is_cuda_graph_enabled:
             if self._max_total_num_rows is None:
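
The restructuring above means that supplying max_token_per_sequence and max_sequence_kv lets plan() skip the device-to-host copies of qo_indptr and paged_kv_indptr entirely. A sketch of hoisting that sync out of the hot path (continuing the example above):

# One device-to-host sync, done once ahead of time instead of inside every plan() call.
max_q = int(seq_lens_q.max().item())
max_kv = int(seq_lens_kv.max().item())

wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    8,    # num_qo_heads
    8,    # num_kv_heads
    128,  # head_dim_qk
    page_size,
    causal=True,
    seq_lens=seq_lens_kv,
    seq_lens_q=seq_lens_q,
    block_tables=block_tables,
    max_token_per_sequence=max_q,  # skips the qo_indptr host copy
    max_sequence_kv=max_kv,        # skips the paged_kv_indptr host copy
)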
@@ -1756,23 +1777,23 @@ def plan(
                 q_data_type,
                 kv_data_type,
             )
+            if self._backend != "cudnn":
+                get_module_args = (
+                    q_data_type,
+                    kv_data_type,
+                    q_data_type,
+                    paged_kv_indptr.dtype,
+                    head_dim_qk,
+                    head_dim_vo,
+                    PosEncodingMode[pos_encoding_mode].value,
+                    window_left >= 0,  # use_sliding_window
+                    logits_soft_cap > 0,  # use_logits_soft_cap
+                    use_fp16_qk_reduction,
+                )

-            get_module_args = (
-                q_data_type,
-                kv_data_type,
-                q_data_type,
-                paged_kv_indptr.dtype,
-                head_dim_qk,
-                head_dim_vo,
-                PosEncodingMode[pos_encoding_mode].value,
-                window_left >= 0,  # use_sliding_window
-                logits_soft_cap > 0,  # use_logits_soft_cap
-                use_fp16_qk_reduction,
-            )
-
-            self._cached_module = get_batch_prefill_module(
-                self._backend, *get_module_args
-            )
+                self._cached_module = get_batch_prefill_module(
+                    self._backend, *get_module_args
+                )

         if self._backend == "fa3" or self._backend == "trtllm-gen":
             if page_size != 1:
@@ -1790,7 +1811,7 @@ def plan(
                 ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
                 paged_kv_indptr_host = vector_sparse_indptr_host

-        self._block_tables: Optional[torch.Tensor] = block_tables
+        self._block_tables = block_tables
         if self._backend == "trtllm-gen":
             assert self._kv_layout == "HND"
             assert logits_soft_cap == 0.0
@@ -1808,28 +1829,32 @@ def plan(
                 block_id = paged_kv_indptr_host[0]
                 for i in range(batch_size):
                     num_blocks_needed = blocks_per_seq[i]
+                    assert self._block_tables is not None, (
+                        "block_tables is not initialized"
+                    )
                     self._block_tables[i, :num_blocks_needed] = paged_kv_indices[
                         block_id : block_id + num_blocks_needed
                     ]
                     block_id += num_blocks_needed

-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            paged_kv_indptr_host,
-            kv_lens_arr_host,
-            self._max_total_num_rows or total_num_rows,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim_qk,
-            head_dim_vo,
-            causal,
-        )
+        if self._cached_module is not None:
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                paged_kv_indptr_host,
+                kv_lens_arr_host,
+                self._max_total_num_rows or total_num_rows,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim_qk,
+                head_dim_vo,
+                causal,
+            )

         self._causal = causal
         self._pos_encoding_mode = pos_encoding_mode
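
The guarded loop above scatters the flat paged_kv_indices into a dense [batch_size, max_blocks_per_seq] block table. The same layout in isolation (hypothetical page ids):

import torch

paged_kv_indices = torch.tensor([7, 3, 9, 1, 4], dtype=torch.int32)  # flat page ids
paged_kv_indptr = torch.tensor([0, 2, 5], dtype=torch.int32)         # seq 0 -> 2 pages, seq 1 -> 3
batch_size = paged_kv_indptr.numel() - 1
blocks_per_seq = paged_kv_indptr[1:] - paged_kv_indptr[:-1]

block_tables = torch.zeros(batch_size, int(blocks_per_seq.max()), dtype=torch.int32)
block_id = int(paged_kv_indptr[0])
for i in range(batch_size):
    n = int(blocks_per_seq[i])
    block_tables[i, :n] = paged_kv_indices[block_id : block_id + n]
    block_id += n
# block_tables == [[7, 3, 0], [9, 1, 4]]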
@@ -1839,6 +1864,8 @@ def plan(
         self._sm_scale = sm_scale
         self._rope_scale = rope_scale
         self._rope_theta = rope_theta
+        self._seq_lens_kv = seq_lens
+        self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens

     begin_forward = plan

@@ -2039,56 +2066,84 @@ def run(
             sparse_indices = self._paged_kv_indices_buf
             sparse_indptr = self._paged_kv_indptr_buf

-        run_args = [
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k_cache,
-            v_cache,
-            self._qo_indptr_buf,
-            sparse_indptr,
-            sparse_indices,
-            self._paged_kv_last_page_len_buf,
-            out,
-            lse,
-            mask_mode,
-            TensorLayout[self._kv_layout].value,
-            window_left,
-            enable_pdl,
-        ]
-        if self._jit_module is not None:
-            run_args.extend(list(args))
+        if self._backend == "cudnn":
+            if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
+                self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1)
+
+            if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1:
+                self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1)
+
+            cudnn_batch_prefill_with_kv_cache(
+                q,
+                k_cache,  # Need to be changed
+                v_cache,  # Need to be changed
+                self._sm_scale,
+                self._float_workspace_buffer,
+                actual_seq_lens_q=self._seq_lens_q,
+                actual_seq_lens_kv=self._seq_lens_kv,
+                max_token_per_sequence=self._max_q_len,
+                max_sequence_kv=self._max_kv_len,
+                block_tables=self._block_tables,
+                causal=self._causal,
+                return_lse=return_lse,
+                batch_offsets_q=self._qo_indptr_buf,
+                batch_offsets_o=self._qo_indptr_buf,
+                out=out,
+                lse=lse,
+            )
         else:
-            run_args += [
-                self._custom_mask_buf,
-                self._mask_indptr_buf,
-                _get_cache_alibi_slopes_buf(q.shape[1], q.device),
-                self._prefix_len_ptr,
-                self._token_pos_in_items_ptr,
-                self._max_item_len_ptr,
-                logits_soft_cap,
-                sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
-                rope_scale,
-                rope_theta,
-                self._token_pos_in_items_len,
-                self._num_qo_heads,
-                self._num_kv_heads,
-                self._block_tables,
-                self._kv_lens_buffer,
-                page_size,
-                self._max_q_len,
-                self._max_kv_len,
-                self._batch_size,
+            assert self._plan_info is not None, "plan info is not initialized"
+            run_args = [
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k_cache,
+                v_cache,
                 self._qo_indptr_buf,
-                self._vector_sparse_indptr_buffer,
-                sinks,
+                sparse_indptr,
+                sparse_indices,
+                self._paged_kv_last_page_len_buf,
+                out,
+                lse,
+                mask_mode,
+                TensorLayout[self._kv_layout].value,
+                window_left,
+                enable_pdl,
             ]
+            if self._jit_module is not None:
+                run_args.extend(list(args))
+            else:
+                run_args += [
+                    self._custom_mask_buf,
+                    self._mask_indptr_buf,
+                    _get_cache_alibi_slopes_buf(q.shape[1], q.device),
+                    self._prefix_len_ptr,
+                    self._token_pos_in_items_ptr,
+                    self._max_item_len_ptr,
+                    logits_soft_cap,
+                    sm_scale,
+                    None,  # scale_q, not supported yet
+                    None,  # scale_k
+                    None,  # scale_v
+                    rope_scale,
+                    rope_theta,
+                    self._token_pos_in_items_len,
+                    self._num_qo_heads,
+                    self._num_kv_heads,
+                    self._block_tables,
+                    self._kv_lens_buffer,
+                    page_size,
+                    self._max_q_len,
+                    self._max_kv_len,
+                    self._batch_size,
+                    self._qo_indptr_buf,
+                    self._vector_sparse_indptr_buffer,
+                    sinks,
+                ]

-        self._cached_module.paged_run(*run_args)
+            assert self._cached_module is not None, "cached module is not initialized"
+            self._cached_module.paged_run(*run_args)

         return (out, lse) if return_lse else out
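
Putting the pieces together, a hedged end-to-end sketch of the cudnn path (continuing the plan() example above; shapes are illustrative, and run() is assumed to accept the paged KV-cache as a (k_cache, v_cache) tuple as in the existing wrapper API):

num_pages, head_dim = 2, 128
total_q = 6  # sum of q lengths [4, 2]
q = torch.randn(total_q, 8, head_dim, dtype=torch.half, device="cuda")
# NHD layout: [num_pages, page_size, num_kv_heads, head_dim]
k_cache = torch.randn(num_pages, page_size, 8, head_dim, dtype=torch.half, device="cuda")
v_cache = torch.randn_like(k_cache)

out = wrapper.run(q, (k_cache, v_cache))  # dispatches to cudnn_batch_prefill_with_kv_cache
out, lse = wrapper.run(q, (k_cache, v_cache), return_lse=True)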

@@ -2343,6 +2398,7 @@ def __init__(
         self._mask_indptr_buf = mask_indptr_buf
         self._max_total_num_rows = None
         self._backend = backend
+        self._cached_module = None

     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -2613,6 +2669,7 @@ def plan(
             )
             self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item()
         else:
+            assert self._cached_module is not None, "cached module is not initialized"
             self._plan_info = self._cached_module.plan(
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
@@ -2837,6 +2894,7 @@ def run(
             self._token_pos_in_items_len,
         ]

+        assert self._cached_module is not None, "cached module is not initialized"
         self._cached_module.ragged_run(*run_args)
         return (out, lse) if return_lse else out
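
The repeated asserts follow from _cached_module now being initialized to None and only populated for non-cudnn backends: asserting before each use fails fast with a clear message and narrows the Optional for static type checkers. The pattern in isolation (a generic sketch, not the wrapper's actual class):

from typing import Optional


class Example:
    def __init__(self) -> None:
        self._cached_module: Optional[object] = None  # populated lazily by plan()

    def run(self) -> None:
        # Narrows Optional[object] to object for mypy/pyright; fails fast otherwise.
        assert self._cached_module is not None, "cached module is not initialized"
        print(self._cached_module)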
