@@ -36,6 +36,7 @@
     setup_metainfo_loader,
     trtllm_gen_fmha_module,
 )
+from .cudnn import cudnn_batch_prefill_with_kv_cache
 from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
 from .quantization import packbits, segment_packbits
 from .utils import (
@@ -1368,7 +1369,7 @@ def __init__(
             mask will be used in attention computation.

         backend : str
-            The implementation backend, could be ``auto``/``fa2`` or ``fa3``. Defaults to ``auto``.
+            The implementation backend, could be ``auto``/``fa2``/``fa3`` or ``cudnn``. Defaults to ``auto``.
             If set to ``auto``, the wrapper will automatically choose the backend based on the
             device architecture and kernel availability.

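With this change, selecting the new backend is just a constructor argument. A minimal usage sketch (the workspace size and device are illustrative choices, not values from this diff):

```python
import torch
import flashinfer

# A workspace buffer for the wrapper; 128 MB is an arbitrary illustrative size.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace,
    kv_layout="NHD",  # the cudnn backend asserts NHD, see the next hunk
    backend="cudnn",
)
```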
@@ -1392,6 +1393,9 @@ def __init__(
         self._jit_module = None

         self._kv_layout = kv_layout
+        if backend == "cudnn":
+            assert kv_layout == "NHD", "CUDNN backend only supports NHD layout"
+
         self._float_workspace_buffer = float_workspace_buffer
         self.device = float_workspace_buffer.device
         self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None
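For context on the new assertion: in flashinfer's paged KV cache, ``NHD`` stores each page as ``[page_size, num_kv_heads, head_dim]`` while ``HND`` stores ``[num_kv_heads, page_size, head_dim]``. A small sketch of the difference (shapes are illustrative):

```python
import torch

num_pages, page_size, num_kv_heads, head_dim = 8, 16, 4, 128

# NHD layout: tokens first, then heads.
k_nhd = torch.zeros(num_pages, page_size, num_kv_heads, head_dim)

# HND layout: heads first, then tokens; one transpose away from NHD.
k_hnd = k_nhd.transpose(1, 2).contiguous()
assert k_hnd.shape == (num_pages, num_kv_heads, page_size, head_dim)
```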
@@ -1456,6 +1460,11 @@ def __init__(
         self._mask_indptr_buf = mask_indptr_buf
         self._max_total_num_rows = None
         self._backend = backend
+        self._plan_info = None
+        self._cached_module = None
+        self._seq_lens_kv = None
+        self._seq_lens_q = None
+        self._block_tables = None

     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -1514,7 +1523,10 @@ def plan(
         token_pos_in_items_len: int = 0,
         max_item_len_ptr: Optional[torch.Tensor] = None,
         seq_lens: Optional[torch.Tensor] = None,
+        seq_lens_q: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        max_token_per_sequence: Optional[int] = None,
+        max_sequence_kv: Optional[int] = None,
     ) -> None:
         r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.

@@ -1605,8 +1617,15 @@ def plan(
             a uint16 vector contains the max token length of all items for each prompt
         seq_lens: Optional[torch.Tensor]
             A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
+        seq_lens_q: Optional[torch.Tensor]
+            A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``.
+            If not provided, defaults to the same value as ``seq_lens``.
         block_tables: Optional[torch.Tensor]
             A uint32 2D tensor indicating the block table of each prompt. shape: ``[batch_size, max_num_blocks_per_seq]``.
+        max_token_per_sequence: Optional[int]
+            Required for the cudnn backend. The scalar maximum number of query tokens of any sequence.
+        max_sequence_kv: Optional[int]
+            Required for the cudnn backend. The scalar maximum kv sequence length of any sequence in the kv cache.

         Note
         ----
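To illustrate the new arguments, here is a hedged sketch of a ``plan()`` call targeting the cudnn backend, reusing the ``wrapper`` from the constructor sketch above; all tensor values are dummies, and the positional arguments follow the upstream ``plan()`` signature as documented:

```python
import torch

batch_size, page_size = 2, 16
num_qo_heads, num_kv_heads, head_dim = 8, 8, 128

seq_lens_kv = torch.tensor([100, 64], dtype=torch.int32, device="cuda")
seq_lens_q = seq_lens_kv.clone()  # pure prefill: q lengths equal kv lengths
qo_indptr = torch.tensor([0, 100, 164], dtype=torch.int32, device="cuda")
paged_kv_indptr = torch.tensor([0, 7, 11], dtype=torch.int32, device="cuda")
paged_kv_indices = torch.arange(11, dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.tensor([4, 16], dtype=torch.int32, device="cuda")
block_tables = torch.zeros(batch_size, 7, dtype=torch.int32, device="cuda")

wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    causal=True,
    seq_lens=seq_lens_kv,
    seq_lens_q=seq_lens_q,
    block_tables=block_tables,
    max_token_per_sequence=int(seq_lens_q.max()),  # scalar hint: skips the host-side qo_indptr scan
    max_sequence_kv=int(seq_lens_kv.max()),        # scalar hint: skips the host-side kv-length scan
)
```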
@@ -1655,22 +1674,28 @@ def plan(
         self._max_item_len_ptr = max_item_len_ptr

         # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
-        paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
-        if seq_lens is None:
-            kv_lens_arr_host = get_seq_lens(
-                paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
-            )
+        if max_token_per_sequence is not None:
+            self._max_q_len = max_token_per_sequence
         else:
-            kv_lens_arr_host = seq_lens.cpu()
-        self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
-            kv_lens_arr_host, non_blocking=non_blocking
-        )
-        self._max_q_len = max(qo_indptr_host).item()
-        self._max_kv_len = max(kv_lens_arr_host).item()
+            qo_indptr_host = qo_indptr.to("cpu")
+            self._max_q_len = max(qo_indptr_host).item()
+            total_num_rows = qo_indptr_host[-1]

-        total_num_rows = qo_indptr_host[-1]
+        if max_sequence_kv is not None:
+            self._max_kv_len = max_sequence_kv
+        else:
+            paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+            paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
+            if seq_lens is None:
+                kv_lens_arr_host = get_seq_lens(
+                    paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
+                )
+            else:
+                kv_lens_arr_host = seq_lens.cpu().flatten()
+            self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
+                kv_lens_arr_host, non_blocking=non_blocking
+            )
+            self._max_kv_len = max(kv_lens_arr_host).item()

         if self.is_cuda_graph_enabled:
             if self._max_total_num_rows is None:
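When ``max_sequence_kv`` is absent, the kv lengths are recovered on the host from the page table, as the new ``else`` branch shows. A standalone sketch of the arithmetic that ``get_seq_lens`` performs (values are illustrative):

```python
import torch

page_size = 16
paged_kv_indptr = torch.tensor([0, 7, 11])      # request i owns pages indptr[i]:indptr[i+1]
paged_kv_last_page_len = torch.tensor([4, 16])  # valid tokens in each request's last page

num_pages = paged_kv_indptr[1:] - paged_kv_indptr[:-1]
kv_lens = (num_pages - 1) * page_size + paged_kv_last_page_len
assert kv_lens.tolist() == [100, 64]  # 6 * 16 + 4 and 3 * 16 + 16
```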
@@ -1759,23 +1784,23 @@ def plan(
             q_data_type,
             kv_data_type,
         )
+        if self._backend != "cudnn":
+            get_module_args = (
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                paged_kv_indptr.dtype,
+                head_dim_qk,
+                head_dim_vo,
+                PosEncodingMode[pos_encoding_mode].value,
+                window_left >= 0,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+                use_fp16_qk_reduction,
+            )

-        get_module_args = (
-            q_data_type,
-            kv_data_type,
-            q_data_type,
-            paged_kv_indptr.dtype,
-            head_dim_qk,
-            head_dim_vo,
-            PosEncodingMode[pos_encoding_mode].value,
-            window_left >= 0,  # use_sliding_window
-            logits_soft_cap > 0,  # use_logits_soft_cap
-            use_fp16_qk_reduction,
-        )
-
-        self._cached_module = get_batch_prefill_module(
-            self._backend, *get_module_args
-        )
+            self._cached_module = get_batch_prefill_module(
+                self._backend, *get_module_args
+            )

         if self._backend == "fa3" or self._backend == "trtllm-gen":
             if page_size != 1:
@@ -1793,7 +1818,7 @@ def plan(
                 ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
                 paged_kv_indptr_host = vector_sparse_indptr_host

-        self._block_tables: Optional[torch.Tensor] = block_tables
+        self._block_tables = block_tables
         if self._backend == "trtllm-gen":
             assert self._kv_layout == "HND"
             assert logits_soft_cap == 0.0
@@ -1811,28 +1836,32 @@ def plan(
             block_id = paged_kv_indptr_host[0]
             for i in range(batch_size):
                 num_blocks_needed = blocks_per_seq[i]
+                assert self._block_tables is not None, (
+                    "block_tables is not initialized"
+                )
                 self._block_tables[i, :num_blocks_needed] = paged_kv_indices[
                     block_id : block_id + num_blocks_needed
                 ]
                 block_id += num_blocks_needed

-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            paged_kv_indptr_host,
-            kv_lens_arr_host,
-            self._max_total_num_rows or total_num_rows,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim_qk,
-            head_dim_vo,
-            causal,
-        )
+        if self._cached_module is not None:
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                paged_kv_indptr_host,
+                kv_lens_arr_host,
+                self._max_total_num_rows or total_num_rows,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim_qk,
+                head_dim_vo,
+                causal,
+            )

         self._causal = causal
         self._pos_encoding_mode = pos_encoding_mode
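The loop above flattens the CSR-style ``paged_kv_indices``/``paged_kv_indptr`` pair into a dense per-request block table. A self-contained sketch of the same packing with dummy page ids:

```python
import torch

batch_size = 2
paged_kv_indptr = torch.tensor([0, 7, 11])
paged_kv_indices = torch.arange(11)  # page ids; contiguous here only for clarity
blocks_per_seq = paged_kv_indptr[1:] - paged_kv_indptr[:-1]

block_tables = torch.zeros(batch_size, int(blocks_per_seq.max()), dtype=paged_kv_indices.dtype)
block_id = int(paged_kv_indptr[0])
for i in range(batch_size):
    n = int(blocks_per_seq[i])
    block_tables[i, :n] = paged_kv_indices[block_id : block_id + n]
    block_id += n

assert block_tables[1].tolist() == [7, 8, 9, 10, 0, 0, 0]  # rows are zero-padded
```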
@@ -1842,6 +1871,8 @@ def plan(
         self._sm_scale = sm_scale
         self._rope_scale = rope_scale
         self._rope_theta = rope_theta
+        self._seq_lens_kv = seq_lens
+        self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens

     begin_forward = plan

@@ -2042,62 +2073,90 @@ def run(
             sparse_indices = self._paged_kv_indices_buf
             sparse_indptr = self._paged_kv_indptr_buf

-        run_args = [
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k_cache,
-            v_cache,
-            self._qo_indptr_buf,
-            sparse_indptr,
-            sparse_indices,
-            self._paged_kv_last_page_len_buf,
-            out,
-            lse,
-            mask_mode,
-            TensorLayout[self._kv_layout].value,
-            window_left,
-            enable_pdl,
-        ]
-        if self._jit_module is not None:
-            run_args.extend(list(args))
+        if self._backend == "cudnn":
+            if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
+                self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1)
+
+            if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1:
+                self._seq_lens_kv = self._seq_lens_kv.reshape(self._batch_size, 1, 1, 1)
+
+            cudnn_batch_prefill_with_kv_cache(
+                q,
+                k_cache,  # Need to be changed
+                v_cache,  # Need to be changed
+                self._sm_scale,
+                self._float_workspace_buffer,
+                actual_seq_lens_q=self._seq_lens_q,
+                actual_seq_lens_kv=self._seq_lens_kv,
+                max_token_per_sequence=self._max_q_len,
+                max_sequence_kv=self._max_kv_len,
+                block_tables=self._block_tables,
+                causal=self._causal,
+                return_lse=return_lse,
+                batch_offsets_q=self._qo_indptr_buf,
+                batch_offsets_o=self._qo_indptr_buf,
+                out=out,
+                lse=lse,
+            )
         else:
-            run_args += [
-                self._custom_mask_buf,
-                self._mask_indptr_buf,
-                _get_cache_alibi_slopes_buf(q.shape[1], q.device),
-                self._prefix_len_ptr,
-                self._token_pos_in_items_ptr,
-                self._max_item_len_ptr,
-                logits_soft_cap,
-                sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
-                rope_scale,
-                rope_theta,
-                self._token_pos_in_items_len,
-                self._num_qo_heads,
-                self._num_kv_heads,
-                self._block_tables,
-                self._kv_lens_buffer,
-                page_size,
-                self._max_q_len,
-                self._max_kv_len,
-                self._batch_size,
+            assert self._plan_info is not None, "plan info is not initialized"
+            run_args = [
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k_cache,
+                v_cache,
                 self._qo_indptr_buf,
-                self._vector_sparse_indptr_buffer,
-                sinks,
+                sparse_indptr,
+                sparse_indices,
+                self._paged_kv_last_page_len_buf,
+                out,
+                lse,
+                mask_mode,
+                TensorLayout[self._kv_layout].value,
+                window_left,
+                enable_pdl,
             ]
-
-        self._cached_module.paged_run(*run_args)
-        if v_scale is not None:
-            # TODO(Zihao): fused into kernel
-            if is_float8(out):
-                out = (out.to(torch.float32) * v_scale).to(out.dtype)
+            if self._jit_module is not None:
+                run_args.extend(list(args))
             else:
-                out *= v_scale
+                run_args += [
+                    self._custom_mask_buf,
+                    self._mask_indptr_buf,
+                    _get_cache_alibi_slopes_buf(q.shape[1], q.device),
+                    self._prefix_len_ptr,
+                    self._token_pos_in_items_ptr,
+                    self._max_item_len_ptr,
+                    logits_soft_cap,
+                    sm_scale,
+                    None,  # scale_q, not supported yet
+                    None,  # scale_k
+                    None,  # scale_v
+                    rope_scale,
+                    rope_theta,
+                    self._token_pos_in_items_len,
+                    self._num_qo_heads,
+                    self._num_kv_heads,
+                    self._block_tables,
+                    self._kv_lens_buffer,
+                    page_size,
+                    self._max_q_len,
+                    self._max_kv_len,
+                    self._batch_size,
+                    self._qo_indptr_buf,
+                    self._vector_sparse_indptr_buffer,
+                    sinks,
+                ]
+
+            assert self._cached_module is not None, "cached module is not initialized"
+            self._cached_module.paged_run(*run_args)
+            if v_scale is not None:
+                # TODO(Zihao): fused into kernel
+                if is_float8(out):
+                    out = (out.to(torch.float32) * v_scale).to(out.dtype)
+                else:
+                    out *= v_scale

         return (out, lse) if return_lse else out

     run_return_lse = functools.partialmethod(run, return_lse=True)
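The cudnn branch normalizes 1-D per-request lengths into the ``[batch_size, 1, 1, 1]`` shape that ``cudnn_batch_prefill_with_kv_cache`` expects. A tiny sketch of that reshape:

```python
import torch

batch_size = 2
seq_lens_q = torch.tensor([100, 64], dtype=torch.int32)

if seq_lens_q.dim() == 1:
    seq_lens_q = seq_lens_q.reshape(batch_size, 1, 1, 1)
assert seq_lens_q.shape == (2, 1, 1, 1)
```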
@@ -2351,6 +2410,7 @@ def __init__(
         self._mask_indptr_buf = mask_indptr_buf
         self._max_total_num_rows = None
         self._backend = backend
+        self._cached_module = None

     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -2621,6 +2681,7 @@ def plan(
             )
             self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item()
         else:
+            assert self._cached_module is not None, "cached module is not initialized"
             self._plan_info = self._cached_module.plan(
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
@@ -2845,6 +2906,7 @@ def run(
                 self._token_pos_in_items_len,
             ]

+        assert self._cached_module is not None, "cached module is not initialized"
         self._cached_module.ragged_run(*run_args)
         return (out, lse) if return_lse else out

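End to end, running a planned cudnn prefill looks the same as any other backend at the call site. A sketch assuming the ``wrapper`` and ``plan()`` call from the earlier examples, with dummy NHD caches:

```python
import torch

num_pages, page_size, num_heads, head_dim = 11, 16, 8, 128
q = torch.randn(164, num_heads, head_dim, dtype=torch.half, device="cuda")  # [total_q_tokens, heads, dim]
k_cache = torch.randn(num_pages, page_size, num_heads, head_dim, dtype=torch.half, device="cuda")  # NHD
v_cache = torch.randn_like(k_cache)

out = wrapper.run(q, (k_cache, v_cache))  # dispatches to cudnn_batch_prefill_with_kv_cache
out, lse = wrapper.run_return_lse(q, (k_cache, v_cache))
```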