@@ -1366,7 +1366,7 @@ def __init__(
            mask will be used in attention computation.

        backend : str
-            The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
+            The implementation backend, could be ``auto``/``fa2``/``fa3`` or ``cudnn``. Defaults to ``auto``.
            If set to ``auto``, the wrapper will automatically choose the backend based on the
            device architecture and kernel availability.

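Usage sketch (illustrative, not part of the diff): selecting the new ``cudnn`` backend when constructing the wrapper. Assumes the existing public constructor of ``BatchPrefillWithPagedKVCacheWrapper``; the NHD requirement mirrors the assert added in ``__init__`` below.

import torch
import flashinfer

# 128 MB float workspace buffer, as in the flashinfer examples.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer,
    kv_layout="NHD",   # the cudnn backend only supports the NHD layout (asserted below)
    backend="cudnn",   # new option introduced by this change
)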
@@ -1388,6 +1388,9 @@ def __init__(
        self._jit_module = None

        self._kv_layout = kv_layout
+        if backend == "cudnn":
+            assert kv_layout == "NHD", "CUDNN backend only supports NHD layout"
+
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None
@@ -1452,6 +1455,10 @@ def __init__(
        self._mask_indptr_buf = mask_indptr_buf
        self._max_total_num_rows = None
        self._backend = backend
+        self._cached_module = None
+        self._seq_lens_kv = None
+        self._seq_lens_q = None
+        self._block_tables = None

    @property
    def is_cuda_graph_enabled(self) -> bool:
@@ -1510,7 +1517,10 @@ def plan(
        token_pos_in_items_len: int = 0,
        max_item_len_ptr: Optional[torch.Tensor] = None,
        seq_lens: Optional[torch.Tensor] = None,
+        seq_lens_q: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
+        max_token_per_sequence: Optional[int] = None,
+        max_sequence_kv: Optional[int] = None,
    ) -> None:
        r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.

@@ -1601,6 +1611,9 @@ def plan(
            a uint16 vector contains the max token length of all items for each prompt
        seq_lens: Optional[torch.Tensor]
            A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
+        seq_lens_q: Optional[torch.Tensor]
+            A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``.
+            If not provided, it will be set to the same value as ``seq_lens``.
        block_tables: Optional[torch.Tensor]
            A uint32 2D tensor indicating the block table of each prompt. shape: ``[batch_size, max_num_blocks_per_seq]``.

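Plan-time sketch (illustrative only; every shape and value below is made up): how the new optional arguments might be passed alongside the existing paged-KV metadata. ``seq_lens_q``, ``max_token_per_sequence``, and ``max_sequence_kv`` are taken from the signature added above; when the two maxima are supplied, ``plan()`` can skip the host-side derivation of the max q/kv lengths.

# Two requests, page_size 16; request 0 holds 23 kv tokens (pages 0-1),
# request 1 holds 48 kv tokens (pages 2-4). Query lengths are 4 and 6.
page_size = 16
qo_indptr = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
paged_kv_indptr = torch.tensor([0, 2, 5], dtype=torch.int32, device="cuda")
paged_kv_indices = torch.arange(5, dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.tensor([7, 16], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([23, 48], dtype=torch.int32, device="cuda")
seq_lens_q = torch.tensor([4, 6], dtype=torch.int32, device="cuda")
block_tables = torch.tensor([[0, 1, 0], [2, 3, 4]], dtype=torch.int32, device="cuda")

prefill_wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads=8,
    num_kv_heads=8,
    head_dim_qk=128,
    page_size=page_size,
    causal=True,
    seq_lens=seq_lens,
    seq_lens_q=seq_lens_q,                         # new: per-request q lengths
    max_token_per_sequence=int(seq_lens_q.max()),  # new: used directly as the max q length
    max_sequence_kv=int(seq_lens.max()),           # new: used directly as the max kv length
    block_tables=block_tables,
)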
@@ -1651,22 +1664,28 @@ def plan(
        self._max_item_len_ptr = max_item_len_ptr

        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
-        paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
-        if seq_lens is None:
-            kv_lens_arr_host = get_seq_lens(
-                paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
-            )
+        if max_token_per_sequence is not None:
+            self._max_q_len = max_token_per_sequence
        else:
-            kv_lens_arr_host = seq_lens.cpu()
-        self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
-            kv_lens_arr_host, non_blocking=non_blocking
-        )
-        self._max_q_len = max(qo_indptr_host).item()
-        self._max_kv_len = max(kv_lens_arr_host).item()
+            qo_indptr_host = qo_indptr.to("cpu")
+            self._max_q_len = max(qo_indptr_host).item()
+            total_num_rows = qo_indptr_host[-1]

-        total_num_rows = qo_indptr_host[-1]
+        if max_sequence_kv is not None:
+            self._max_kv_len = max_sequence_kv
+        else:
+            paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+            paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
+            if seq_lens is None:
+                kv_lens_arr_host = get_seq_lens(
+                    paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
+                )
+            else:
+                kv_lens_arr_host = seq_lens.cpu().flatten()
+            self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
+                kv_lens_arr_host, non_blocking=non_blocking
+            )
+            self._max_kv_len = max(kv_lens_arr_host).item()

        if self.is_cuda_graph_enabled:
            if self._max_total_num_rows is None:
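For reference (a paraphrase, not the library code): when neither ``seq_lens`` nor ``max_sequence_kv`` is given, the kv lengths are still reconstructed on the host from the paged layout, roughly as in the sketch below.

# Rough host-side equivalent of get_seq_lens() as used in the else-branch above:
# each request covers (num_pages - 1) full pages plus the tokens in its last page.
def approx_kv_lens(paged_kv_indptr_host, paged_kv_last_page_len_host, page_size):
    num_pages = paged_kv_indptr_host[1:] - paged_kv_indptr_host[:-1]
    return (num_pages - 1).clamp(min=0) * page_size + paged_kv_last_page_len_host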
@@ -1755,23 +1774,23 @@ def plan(
            q_data_type,
            kv_data_type,
        )
+        if self._backend != "cudnn":
+            get_module_args = (
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                paged_kv_indptr.dtype,
+                head_dim_qk,
+                head_dim_vo,
+                PosEncodingMode[pos_encoding_mode].value,
+                window_left >= 0,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+                use_fp16_qk_reduction,
+            )

-        get_module_args = (
-            q_data_type,
-            kv_data_type,
-            q_data_type,
-            paged_kv_indptr.dtype,
-            head_dim_qk,
-            head_dim_vo,
-            PosEncodingMode[pos_encoding_mode].value,
-            window_left >= 0,  # use_sliding_window
-            logits_soft_cap > 0,  # use_logits_soft_cap
-            use_fp16_qk_reduction,
-        )
-
-        self._cached_module = get_batch_prefill_module(
-            self._backend, *get_module_args
-        )
+            self._cached_module = get_batch_prefill_module(
+                self._backend, *get_module_args
+            )

        if self._backend == "fa3" or self._backend == "trtllm-gen":
            if page_size != 1:
@@ -1789,7 +1808,6 @@ def plan(
                ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
                paged_kv_indptr_host = vector_sparse_indptr_host

-        self._block_tables: Optional[torch.Tensor] = block_tables
        if self._backend == "trtllm-gen":
            assert self._kv_layout == "HND"
            assert logits_soft_cap == 0.0
@@ -1812,32 +1830,36 @@ def plan(
                ]
                block_id += num_blocks_needed

-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            paged_kv_indptr_host,
-            kv_lens_arr_host,
-            self._max_total_num_rows or total_num_rows,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim_qk,
-            head_dim_vo,
-            causal,
-        )
+        if self._cached_module is not None:
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                paged_kv_indptr_host,
+                kv_lens_arr_host,
+                self._max_total_num_rows or total_num_rows,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim_qk,
+                head_dim_vo,
+                causal,
+            )

        self._causal = causal
+        self._block_tables = block_tables
        self._pos_encoding_mode = pos_encoding_mode
        self._use_fp16_qk_reduction = use_fp16_qk_reduction
        self._window_left = window_left
        self._logits_soft_cap = logits_soft_cap
        self._sm_scale = sm_scale
        self._rope_scale = rope_scale
        self._rope_theta = rope_theta
+        self._seq_lens_kv = seq_lens
+        self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens

    begin_forward = plan

@@ -2038,56 +2060,83 @@ def run(
        sparse_indices = self._paged_kv_indices_buf
        sparse_indptr = self._paged_kv_indptr_buf

-        run_args = [
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k_cache,
-            v_cache,
-            self._qo_indptr_buf,
-            sparse_indptr,
-            sparse_indices,
-            self._paged_kv_last_page_len_buf,
-            out,
-            lse,
-            mask_mode,
-            TensorLayout[self._kv_layout].value,
-            window_left,
-            enable_pdl,
-        ]
-        if self._jit_module is not None:
-            run_args.extend(list(args))
+        if self._backend == "cudnn":
+
+            if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
+                self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1)
+
+            if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1:
+                self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1)
+
+            cudnn_batch_prefill_with_kv_cache(
+                q,
+                k_cache,  # Need to be changed
+                v_cache,  # Need to be changed
+                self._sm_scale,
+                self._float_workspace_buffer,
+                actual_seq_lens_q=self._seq_lens_q,
+                actual_seq_lens_kv=self._seq_lens_kv,
+                max_token_per_sequence=self._max_q_len,
+                max_sequence_kv=self._max_kv_len,
+                block_tables=self._block_tables,
+                causal=self._causal,
+                return_lse=return_lse,
+                batch_offsets_q=self._qo_indptr_buf,
+                batch_offsets_o=self._qo_indptr_buf,
+                out=out,
+                lse=lse,
+            )
        else:
-            run_args += [
-                self._custom_mask_buf,
-                self._mask_indptr_buf,
-                _get_cache_alibi_slopes_buf(q.shape[1], q.device),
-                self._prefix_len_ptr,
-                self._token_pos_in_items_ptr,
-                self._max_item_len_ptr,
-                logits_soft_cap,
-                sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
-                rope_scale,
-                rope_theta,
-                self._token_pos_in_items_len,
-                self._num_qo_heads,
-                self._num_kv_heads,
-                self._block_tables,
-                self._kv_lens_buffer,
-                page_size,
-                self._max_q_len,
-                self._max_kv_len,
-                self._batch_size,
+            run_args = [
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k_cache,
+                v_cache,
                self._qo_indptr_buf,
-                self._vector_sparse_indptr_buffer,
-                sinks,
+                sparse_indptr,
+                sparse_indices,
+                self._paged_kv_last_page_len_buf,
+                out,
+                lse,
+                mask_mode,
+                TensorLayout[self._kv_layout].value,
+                window_left,
+                enable_pdl,
            ]
+            if self._jit_module is not None:
+                run_args.extend(list(args))
+            else:
+                run_args += [
+                    self._custom_mask_buf,
+                    self._mask_indptr_buf,
+                    _get_cache_alibi_slopes_buf(q.shape[1], q.device),
+                    self._prefix_len_ptr,
+                    self._token_pos_in_items_ptr,
+                    self._max_item_len_ptr,
+                    logits_soft_cap,
+                    sm_scale,
+                    None,  # scale_q, not supported yet
+                    None,  # scale_k
+                    None,  # scale_v
+                    rope_scale,
+                    rope_theta,
+                    self._token_pos_in_items_len,
+                    self._num_qo_heads,
+                    self._num_kv_heads,
+                    self._block_tables,
+                    self._kv_lens_buffer,
+                    page_size,
+                    self._max_q_len,
+                    self._max_kv_len,
+                    self._batch_size,
+                    self._qo_indptr_buf,
+                    self._vector_sparse_indptr_buffer,
+                    sinks,
+                ]

-        self._cached_module.paged_run(*run_args)
+            self._cached_module.paged_run(*run_args)

        return (out, lse) if return_lse else out

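Shape note (illustrative, values made up): the cudnn branch above reshapes 1-D length tensors into the ``[batch_size, 1, 1, 1]`` form consumed by ``cudnn_batch_prefill_with_kv_cache``, so callers may pass either form.

# Both of these end up as shape [batch_size, 1, 1, 1] inside run() on the cudnn path.
seq_lens_q_1d = torch.tensor([4, 6], dtype=torch.int32, device="cuda")  # shape [2]
seq_lens_q_4d = seq_lens_q_1d.reshape(-1, 1, 1, 1)                      # shape [2, 1, 1, 1]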
@@ -2340,6 +2389,7 @@ def __init__(
        self._mask_indptr_buf = mask_indptr_buf
        self._max_total_num_rows = None
        self._backend = backend
+        self._cached_module = None

    @property
    def is_cuda_graph_enabled(self) -> bool: