    setup_metainfo_loader,
    trtllm_gen_fmha_module,
)
+ from .cudnn import cudnn_batch_prefill_with_kv_cache
from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
from .quantization import packbits, segment_packbits
from .utils import (
@@ -1365,7 +1366,7 @@ def __init__(
            mask will be used in attention computation.

        backend : str
-             The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
+             The implementation backend, could be ``auto``, ``fa2``, ``fa3`` or ``cudnn``. Defaults to ``auto``.
            If set to ``auto``, the wrapper will automatically choose the backend based on the
            device architecture and kernel availability.
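As context for the new backend string, here is a minimal construction sketch (assuming this diff targets flashinfer's `BatchPrefillWithPagedKVCacheWrapper`; the workspace size and device below are illustrative assumptions, not prescribed by the diff):

```python
import torch
import flashinfer

# Illustrative workspace buffer; the size is an assumption, not part of this diff.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# The constructor change below asserts kv_layout == "NHD" when backend == "cudnn",
# so the cudnn backend must be paired with the NHD layout.
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout="NHD", backend="cudnn"
)
```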
@@ -1389,6 +1390,9 @@ def __init__(
        self._jit_module = None

        self._kv_layout = kv_layout
+         if backend == "cudnn":
+             assert kv_layout == "NHD", "CUDNN backend only supports NHD layout"
+
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None
@@ -1453,6 +1457,11 @@ def __init__(
        self._mask_indptr_buf = mask_indptr_buf
        self._max_total_num_rows = None
        self._backend = backend
+         self._plan_info = None
+         self._cached_module = None
+         self._seq_lens_kv = None
+         self._seq_lens_q = None
+         self._block_tables = None

    @property
    def is_cuda_graph_enabled(self) -> bool:
@@ -1511,7 +1520,10 @@ def plan(
        token_pos_in_items_len: int = 0,
        max_item_len_ptr: Optional[torch.Tensor] = None,
        seq_lens: Optional[torch.Tensor] = None,
+         seq_lens_q: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
+         max_token_per_sequence: Optional[int] = None,
+         max_sequence_kv: Optional[int] = None,
    ) -> None:
        r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
@@ -1602,6 +1614,9 @@ def plan(
            a uint16 vector contains the max token length of all items for each prompt
        seq_lens: Optional[torch.Tensor]
            A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
+         seq_lens_q: Optional[torch.Tensor]
+             A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``.
+             If not provided, it will be set to the same value as ``seq_lens``.
        block_tables: Optional[torch.Tensor]
            A uint32 2D tensor indicating the block table of each prompt. shape: ``[batch_size, max_num_blocks_per_seq]``.
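To illustrate the new `seq_lens_q` parameter and its defaulting behaviour, a standalone sketch (the tensor values are made-up examples; only the defaulting rule comes from the diff):

```python
import torch

# Per-request kv lengths: all tokens already present in the paged KV cache.
seq_lens = torch.tensor([17, 32, 5], dtype=torch.int32)

# Per-request q lengths: only the new tokens being appended this step.
# For a pure prefill these equal seq_lens; for chunked prefill/append they are shorter.
seq_lens_q = torch.tensor([4, 8, 5], dtype=torch.int32)

# Mirrors the defaulting added in plan(): omitting seq_lens_q reuses the kv lengths.
effective_seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens
```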
@@ -1652,22 +1667,28 @@ def plan(
        self._max_item_len_ptr = max_item_len_ptr

        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-         qo_indptr_host = qo_indptr.to("cpu")
-         paged_kv_indptr_host = paged_kv_indptr.to("cpu")
-         paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
-         if seq_lens is None:
-             kv_lens_arr_host = get_seq_lens(
-                 paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
-             )
+         if max_token_per_sequence is not None:
+             self._max_q_len = max_token_per_sequence
        else:
-             kv_lens_arr_host = seq_lens.cpu()
-         self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
-             kv_lens_arr_host, non_blocking=non_blocking
-         )
-         self._max_q_len = max(qo_indptr_host).item()
-         self._max_kv_len = max(kv_lens_arr_host).item()
+             qo_indptr_host = qo_indptr.to("cpu")
+             self._max_q_len = max(qo_indptr_host).item()
+             total_num_rows = qo_indptr_host[-1]

-         total_num_rows = qo_indptr_host[-1]
+         if max_sequence_kv is not None:
+             self._max_kv_len = max_sequence_kv
+         else:
+             paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+             paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
+             if seq_lens is None:
+                 kv_lens_arr_host = get_seq_lens(
+                     paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
+                 )
+             else:
+                 kv_lens_arr_host = seq_lens.cpu().flatten()
+             self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
+                 kv_lens_arr_host, non_blocking=non_blocking
+             )
+             self._max_kv_len = max(kv_lens_arr_host).item()

        if self.is_cuda_graph_enabled:
            if self._max_total_num_rows is None:
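The fallback branch above derives kv lengths from the page table when neither `max_sequence_kv` nor `seq_lens` is given. A sketch of the equivalent arithmetic (not the library's `get_seq_lens` implementation, just the formula it encodes):

```python
import torch

def kv_lens_from_page_table(
    paged_kv_indptr: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    page_size: int,
) -> torch.Tensor:
    # kv_len_i = (num_pages_i - 1) * page_size + last_page_len_i
    num_pages = paged_kv_indptr[1:] - paged_kv_indptr[:-1]
    return (num_pages - 1) * page_size + paged_kv_last_page_len

# Two requests with 3 and 2 pages of size 16; last pages hold 5 and 16 tokens.
indptr = torch.tensor([0, 3, 5], dtype=torch.int32)
last_page_len = torch.tensor([5, 16], dtype=torch.int32)
print(kv_lens_from_page_table(indptr, last_page_len, 16))  # tensor([37, 32], dtype=torch.int32)
```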
@@ -1756,23 +1777,23 @@ def plan(
                q_data_type,
                kv_data_type,
            )
+         if self._backend != "cudnn":
+             get_module_args = (
+                 q_data_type,
+                 kv_data_type,
+                 q_data_type,
+                 paged_kv_indptr.dtype,
+                 head_dim_qk,
+                 head_dim_vo,
+                 PosEncodingMode[pos_encoding_mode].value,
+                 window_left >= 0,  # use_sliding_window
+                 logits_soft_cap > 0,  # use_logits_soft_cap
+                 use_fp16_qk_reduction,
+             )

-         get_module_args = (
-             q_data_type,
-             kv_data_type,
-             q_data_type,
-             paged_kv_indptr.dtype,
-             head_dim_qk,
-             head_dim_vo,
-             PosEncodingMode[pos_encoding_mode].value,
-             window_left >= 0,  # use_sliding_window
-             logits_soft_cap > 0,  # use_logits_soft_cap
-             use_fp16_qk_reduction,
-         )
-
-         self._cached_module = get_batch_prefill_module(
-             self._backend, *get_module_args
-         )
+             self._cached_module = get_batch_prefill_module(
+                 self._backend, *get_module_args
+             )

        if self._backend == "fa3" or self._backend == "trtllm-gen":
            if page_size != 1:
@@ -1790,7 +1811,7 @@ def plan(
            ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
            paged_kv_indptr_host = vector_sparse_indptr_host

-         self._block_tables: Optional[torch.Tensor] = block_tables
+         self._block_tables = block_tables
        if self._backend == "trtllm-gen":
            assert self._kv_layout == "HND"
            assert logits_soft_cap == 0.0
@@ -1808,28 +1829,32 @@ def plan(
            block_id = paged_kv_indptr_host[0]
            for i in range(batch_size):
                num_blocks_needed = blocks_per_seq[i]
+                 assert self._block_tables is not None, (
+                     "block_tables is not initialized"
+                 )
                self._block_tables[i, :num_blocks_needed] = paged_kv_indices[
                    block_id : block_id + num_blocks_needed
                ]
                block_id += num_blocks_needed

-         self._plan_info = self._cached_module.plan(
-             self._float_workspace_buffer,
-             self._int_workspace_buffer,
-             self._pin_memory_int_workspace_buffer,
-             qo_indptr_host,
-             paged_kv_indptr_host,
-             kv_lens_arr_host,
-             self._max_total_num_rows or total_num_rows,
-             batch_size,
-             num_qo_heads,
-             num_kv_heads,
-             page_size,
-             self.is_cuda_graph_enabled,
-             head_dim_qk,
-             head_dim_vo,
-             causal,
-         )
+         if self._cached_module is not None:
+             self._plan_info = self._cached_module.plan(
+                 self._float_workspace_buffer,
+                 self._int_workspace_buffer,
+                 self._pin_memory_int_workspace_buffer,
+                 qo_indptr_host,
+                 paged_kv_indptr_host,
+                 kv_lens_arr_host,
+                 self._max_total_num_rows or total_num_rows,
+                 batch_size,
+                 num_qo_heads,
+                 num_kv_heads,
+                 page_size,
+                 self.is_cuda_graph_enabled,
+                 head_dim_qk,
+                 head_dim_vo,
+                 causal,
+             )

        self._causal = causal
        self._pos_encoding_mode = pos_encoding_mode
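The loop in the hunk above scatters the CSR-style `(paged_kv_indptr, paged_kv_indices)` pair into a dense per-request block table. A standalone sketch of the same transformation (hypothetical helper, not part of the diff):

```python
import torch

def dense_block_tables(
    paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor
) -> torch.Tensor:
    # One row per request, padded with zeros up to the longest page list.
    blocks_per_seq = paged_kv_indptr[1:] - paged_kv_indptr[:-1]
    batch_size = blocks_per_seq.numel()
    table = torch.zeros(
        batch_size, int(blocks_per_seq.max()), dtype=paged_kv_indices.dtype
    )
    for i in range(batch_size):
        start, end = int(paged_kv_indptr[i]), int(paged_kv_indptr[i + 1])
        table[i, : end - start] = paged_kv_indices[start:end]
    return table

# Request 0 owns pages [7, 2, 9]; request 1 owns pages [4, 1].
indptr = torch.tensor([0, 3, 5])
indices = torch.tensor([7, 2, 9, 4, 1])
print(dense_block_tables(indptr, indices))
# tensor([[7, 2, 9],
#         [4, 1, 0]])
```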
@@ -1839,6 +1864,8 @@ def plan(
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta
+         self._seq_lens_kv = seq_lens
+         self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens

    begin_forward = plan
@@ -2039,56 +2066,84 @@ def run(
            sparse_indices = self._paged_kv_indices_buf
            sparse_indptr = self._paged_kv_indptr_buf

-         run_args = [
-             self._float_workspace_buffer,
-             self._int_workspace_buffer,
-             self._plan_info,
-             q,
-             k_cache,
-             v_cache,
-             self._qo_indptr_buf,
-             sparse_indptr,
-             sparse_indices,
-             self._paged_kv_last_page_len_buf,
-             out,
-             lse,
-             mask_mode,
-             TensorLayout[self._kv_layout].value,
-             window_left,
-             enable_pdl,
-         ]
-         if self._jit_module is not None:
-             run_args.extend(list(args))
+         if self._backend == "cudnn":
+             if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
+                 self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1)
+
+             if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1:
+                 self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1)
+
+             cudnn_batch_prefill_with_kv_cache(
+                 q,
+                 k_cache,  # Need to be changed
+                 v_cache,  # Need to be changed
+                 self._sm_scale,
+                 self._float_workspace_buffer,
+                 actual_seq_lens_q=self._seq_lens_q,
+                 actual_seq_lens_kv=self._seq_lens_kv,
+                 max_token_per_sequence=self._max_q_len,
+                 max_sequence_kv=self._max_kv_len,
+                 block_tables=self._block_tables,
+                 causal=self._causal,
+                 return_lse=return_lse,
+                 batch_offsets_q=self._qo_indptr_buf,
+                 batch_offsets_o=self._qo_indptr_buf,
+                 out=out,
+                 lse=lse,
+             )
        else:
-             run_args += [
-                 self._custom_mask_buf,
-                 self._mask_indptr_buf,
-                 _get_cache_alibi_slopes_buf(q.shape[1], q.device),
-                 self._prefix_len_ptr,
-                 self._token_pos_in_items_ptr,
-                 self._max_item_len_ptr,
-                 logits_soft_cap,
-                 sm_scale,
-                 None,  # scale_q, not supported yet
-                 None,  # scale_k
-                 None,  # scale_v
-                 rope_scale,
-                 rope_theta,
-                 self._token_pos_in_items_len,
-                 self._num_qo_heads,
-                 self._num_kv_heads,
-                 self._block_tables,
-                 self._kv_lens_buffer,
-                 page_size,
-                 self._max_q_len,
-                 self._max_kv_len,
-                 self._batch_size,
+             assert self._plan_info is not None, "plan info is not initialized"
+             run_args = [
+                 self._float_workspace_buffer,
+                 self._int_workspace_buffer,
+                 self._plan_info,
+                 q,
+                 k_cache,
+                 v_cache,
                self._qo_indptr_buf,
-                 self._vector_sparse_indptr_buffer,
-                 sinks,
+                 sparse_indptr,
+                 sparse_indices,
+                 self._paged_kv_last_page_len_buf,
+                 out,
+                 lse,
+                 mask_mode,
+                 TensorLayout[self._kv_layout].value,
+                 window_left,
+                 enable_pdl,
            ]
+             if self._jit_module is not None:
+                 run_args.extend(list(args))
+             else:
+                 run_args += [
+                     self._custom_mask_buf,
+                     self._mask_indptr_buf,
+                     _get_cache_alibi_slopes_buf(q.shape[1], q.device),
+                     self._prefix_len_ptr,
+                     self._token_pos_in_items_ptr,
+                     self._max_item_len_ptr,
+                     logits_soft_cap,
+                     sm_scale,
+                     None,  # scale_q, not supported yet
+                     None,  # scale_k
+                     None,  # scale_v
+                     rope_scale,
+                     rope_theta,
+                     self._token_pos_in_items_len,
+                     self._num_qo_heads,
+                     self._num_kv_heads,
+                     self._block_tables,
+                     self._kv_lens_buffer,
+                     page_size,
+                     self._max_q_len,
+                     self._max_kv_len,
+                     self._batch_size,
+                     self._qo_indptr_buf,
+                     self._vector_sparse_indptr_buffer,
+                     sinks,
+                 ]

-         self._cached_module.paged_run(*run_args)
+             assert self._cached_module is not None, "cached module is not initialized"
+             self._cached_module.paged_run(*run_args)

        return (out, lse) if return_lse else out
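The cudnn branch above reshapes 1-D `[batch_size]` length tensors into the `[batch_size, 1, 1, 1]` shape it passes to `cudnn_batch_prefill_with_kv_cache`. A small illustration of that shape contract (values are arbitrary):

```python
import torch

batch_size = 3
seq_lens_kv = torch.tensor([37, 32, 16], dtype=torch.int32)
seq_lens_q = torch.tensor([4, 8, 16], dtype=torch.int32)

# Same reshape the run() path applies before calling into cudnn.
actual_seq_lens_kv = seq_lens_kv.reshape(batch_size, 1, 1, 1)
actual_seq_lens_q = seq_lens_q.reshape(batch_size, 1, 1, 1)
print(actual_seq_lens_q.shape)  # torch.Size([3, 1, 1, 1])
```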
@@ -2343,6 +2398,7 @@ def __init__(
        self._mask_indptr_buf = mask_indptr_buf
        self._max_total_num_rows = None
        self._backend = backend
+         self._cached_module = None

    @property
    def is_cuda_graph_enabled(self) -> bool:
@@ -2613,6 +2669,7 @@ def plan(
            )
            self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item()
        else:
+             assert self._cached_module is not None, "cached module is not initialized"
            self._plan_info = self._cached_module.plan(
                self._float_workspace_buffer,
                self._int_workspace_buffer,
@@ -2837,6 +2894,7 @@ def run(
                self._token_pos_in_items_len,
            ]

+         assert self._cached_module is not None, "cached module is not initialized"
        self._cached_module.ragged_run(*run_args)
        return (out, lse) if return_lse else out