1
1
import functools
2
+ from enum import Enum
2
3
from typing import Optional
3
4
4
5
import torch
5
6
6
7
from ..jit import get_cudnn_fmha_gen_module
7
8
9
+ try :
10
+ import cudnn
11
+
12
+ CUDNN_AVAILABLE = True
13
+ except ImportError :
14
+ cudnn = None
15
+ CUDNN_AVAILABLE = False
16
+
17
+ # Global cudnn handle. need to make it per device in future
18
+ _cudnn_handle = None
19
+
20
+
21
+ def _create_cudnn_handle (stream : torch .cuda .Stream ):
22
+ global _cudnn_handle
23
+ if _cudnn_handle is None :
24
+ _cudnn_handle = cudnn .create_handle ()
25
+ cudnn .set_stream (_cudnn_handle , stream .cuda_stream )
26
+ return _cudnn_handle
27
+
28
+
29
+ # Tensor ids
30
+ class UIDs (Enum ):
31
+ RESERVED_INVALID_UID = 0
32
+
33
+ Q_UID = 1 # Query tensor
34
+ K_UID = 2 # Key cache tensor
35
+ V_UID = 3 # Value cache tensor
36
+
37
+ ACTUAL_SEQ_LENS_Q_UID = 100 # Actual sequence lengths for query tensor
38
+ ACTUAL_SEQ_LENS_KV_UID = 101 # Actual sequence lengths for key/value tensor
39
+
40
+ BLOCK_TABLES_UID = 200 # Block tables tensor
41
+ BLOCK_TABLES_K_UID = 201 # Block tables tensor for key
42
+ BLOCK_TABLES_V_UID = 202 # Block tables tensor for value
43
+
44
+ RAGGED_Q_UID = 50 # Ragged query tensor
45
+ RAGGED_O_UID = 51 # Ragged output tensor
46
+ RAGGED_STATS_UID = 52 # Ragged stats tensor
47
+
48
+ O_UID = 1000 # Output tensor
49
+ STATS_UID = 1001 # Stats tensor
50
+
51
+
52
+ def _sdpa_decode_key_fn (
53
+ q : torch .Tensor ,
54
+ k_cache : torch .Tensor ,
55
+ v_cache : torch .Tensor ,
56
+ scale : float ,
57
+ * ,
58
+ max_sequence_kv : int ,
59
+ block_size : Optional [int ] = 1 ,
60
+ actual_seq_lens_q : Optional [torch .Tensor ] = None ,
61
+ actual_seq_lens_kv : Optional [torch .Tensor ] = None ,
62
+ block_tables : Optional [torch .Tensor ] = None ,
63
+ batch_offsets_q : Optional [torch .Tensor ] = None ,
64
+ batch_offsets_o : Optional [torch .Tensor ] = None ,
65
+ ):
66
+ return (
67
+ "decode" ,
68
+ max_sequence_kv ,
69
+ tuple (q .shape ),
70
+ tuple (k_cache .shape ),
71
+ )
72
+
73
+
74
+ if CUDNN_AVAILABLE :
75
+
76
+ @cudnn .jit (heur_modes = [cudnn .heur_mode .A ])
77
+ @cudnn .graph_cache (key_fn = _sdpa_decode_key_fn )
78
+ def _build_decode_graph (
79
+ q : torch .Tensor ,
80
+ k_cache : torch .Tensor ,
81
+ v_cache : torch .Tensor ,
82
+ scale : float ,
83
+ * ,
84
+ max_sequence_kv : int ,
85
+ block_size : Optional [int ] = 1 ,
86
+ actual_seq_lens_q : Optional [torch .Tensor ] = None ,
87
+ actual_seq_lens_kv : Optional [torch .Tensor ] = None ,
88
+ block_tables : Optional [torch .Tensor ] = None ,
89
+ batch_offsets_q : Optional [torch .Tensor ] = None ,
90
+ batch_offsets_o : Optional [torch .Tensor ] = None ,
91
+ ):
92
+ handle = _create_cudnn_handle (torch .cuda .current_stream ())
93
+
94
+ # WAR: override batch offsets for now, as it leads to a poor performance
95
+ batch_offsets_q = None
96
+ batch_offsets_o = None
97
+
98
+ with cudnn .graph (handle ) as (g , _ ):
99
+ if q .dim () == 3 :
100
+ s_qo = 1
101
+ b , h_qo , d_qk = q .shape [0 ], q .shape [1 ], q .shape [2 ]
102
+ elif q .dim () == 4 :
103
+ b , h_qo , s_qo , d_qk = (
104
+ q .shape [0 ],
105
+ q .shape [1 ],
106
+ q .shape [2 ],
107
+ q .shape [3 ],
108
+ )
109
+ else :
110
+ raise ValueError (f"q must have 3 or 4 dimensions, got { q .dim ()} " )
111
+
112
+ assert s_qo == 1 , "q must have a sequence length of 1"
113
+ assert k_cache .dim () == 4 , "k_cache must have 4 dimensions"
114
+
115
+ h_kv = k_cache .shape [1 ]
116
+ s_kv = max_sequence_kv
117
+ d_vo = v_cache .shape [3 ]
118
+
119
+ cudnn_q = g .tensor (
120
+ name = "q" ,
121
+ dim = (b , h_qo , s_qo , d_qk ),
122
+ stride = (h_qo * d_qk , d_qk , d_qk * h_qo , 1 ),
123
+ data_type = cudnn .data_type .BFLOAT16 ,
124
+ )
125
+ if batch_offsets_q is not None :
126
+ ragged_q = g .tensor_like (batch_offsets_q )
127
+ ragged_q .set_uid (UIDs .RAGGED_Q_UID .value )
128
+ cudnn_q .set_ragged_offset (ragged_q )
129
+
130
+ cudnn_k_cache = g .tensor_like (k_cache )
131
+ cudnn_v_cache = g .tensor_like (v_cache )
132
+
133
+ cudnn_q .set_uid (UIDs .Q_UID .value )
134
+ cudnn_k_cache .set_uid (UIDs .K_UID .value )
135
+ cudnn_v_cache .set_uid (UIDs .V_UID .value )
136
+
137
+ if block_tables is not None :
138
+ nd_block_tables = block_tables .reshape (
139
+ block_tables .shape [0 ], 1 , block_tables .shape [1 ], 1
140
+ )
141
+ cudnn_k_block_tables = g .tensor_like (nd_block_tables )
142
+ cudnn_k_block_tables .set_uid (UIDs .BLOCK_TABLES_K_UID .value )
143
+
144
+ cudnn_v_block_tables = g .tensor_like (nd_block_tables )
145
+ cudnn_v_block_tables .set_uid (UIDs .BLOCK_TABLES_V_UID .value )
146
+
147
+ if actual_seq_lens_q is not None :
148
+ cudnn_actual_seq_lens_q = g .tensor_like (actual_seq_lens_q )
149
+ cudnn_actual_seq_lens_q .set_uid (UIDs .ACTUAL_SEQ_LENS_Q_UID .value )
150
+
151
+ if actual_seq_lens_kv is not None :
152
+ cudnn_actual_seq_lens_kv = g .tensor_like (actual_seq_lens_kv )
153
+ cudnn_actual_seq_lens_kv .set_uid (UIDs .ACTUAL_SEQ_LENS_KV_UID .value )
154
+ cudnn_actual_seq_lens_kv .set_is_pass_by_value (False )
155
+
156
+ padding_mask = actual_seq_lens_kv is not None
157
+
158
+ O , _ = g .sdpa (
159
+ name = "sdpa" ,
160
+ q = cudnn_q ,
161
+ k = cudnn_k_cache ,
162
+ v = cudnn_v_cache ,
163
+ seq_len_q = (
164
+ cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
165
+ ),
166
+ seq_len_kv = (
167
+ cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
168
+ ),
169
+ use_padding_mask = padding_mask ,
170
+ is_inference = True ,
171
+ attn_scale = scale ,
172
+ paged_attention_k_table = cudnn_k_block_tables ,
173
+ paged_attention_v_table = cudnn_v_block_tables ,
174
+ paged_attention_max_seq_len_kv = max_sequence_kv ,
175
+ compute_data_type = cudnn .data_type .FLOAT ,
176
+ )
177
+
178
+ if batch_offsets_o is not None :
179
+ ragged_o = g .tensor_like (batch_offsets_o )
180
+ ragged_o .set_uid (UIDs .RAGGED_O_UID .value )
181
+ O .set_ragged_offset (ragged_o )
182
+
183
+ O .set_uid (UIDs .O_UID .value ).set_output (True ).set_dim (
184
+ [b , h_qo , s_qo , d_vo ]
185
+ ).set_stride ([d_vo * h_qo , d_vo , d_vo * h_qo , 1 ]).set_data_type (
186
+ cudnn .data_type .BFLOAT16
187
+ )
188
+
189
+ tensors_to_return = [cudnn_q , cudnn_k_cache , cudnn_v_cache , O ]
190
+
191
+ if actual_seq_lens_q is not None :
192
+ tensors_to_return .append (cudnn_actual_seq_lens_q )
193
+ if actual_seq_lens_kv is not None :
194
+ tensors_to_return .append (cudnn_actual_seq_lens_kv )
195
+
196
+ return g , tensors_to_return
197
+
198
+
199
+ def _batch_decode_with_kv_cache (
200
+ q : torch .Tensor ,
201
+ k_cache : torch .Tensor ,
202
+ v_cache : torch .Tensor ,
203
+ scale : float ,
204
+ workspace_buffer : torch .Tensor ,
205
+ * ,
206
+ max_sequence_kv : int ,
207
+ actual_seq_lens_q : Optional [torch .Tensor ] = None ,
208
+ actual_seq_lens_kv : Optional [torch .Tensor ] = None ,
209
+ block_tables : Optional [torch .Tensor ] = None ,
210
+ block_size : Optional [int ] = 1 ,
211
+ batch_offsets_q : Optional [torch .Tensor ] = None ,
212
+ batch_offsets_o : Optional [torch .Tensor ] = None ,
213
+ batch_offsets_k : Optional [torch .Tensor ] = None ,
214
+ batch_offsets_v : Optional [torch .Tensor ] = None ,
215
+ out : torch .Tensor ,
216
+ ) -> torch .Tensor :
217
+
218
+ graph , tensors = _build_decode_graph (
219
+ q = q ,
220
+ k_cache = k_cache ,
221
+ v_cache = v_cache ,
222
+ scale = scale ,
223
+ max_sequence_kv = max_sequence_kv ,
224
+ actual_seq_lens_q = actual_seq_lens_q ,
225
+ actual_seq_lens_kv = actual_seq_lens_kv ,
226
+ block_tables = block_tables ,
227
+ block_size = block_size ,
228
+ batch_offsets_q = batch_offsets_q if batch_offsets_q is not None else None ,
229
+ batch_offsets_o = batch_offsets_q if batch_offsets_q is not None else None ,
230
+ )
231
+
232
+ handle_ = _create_cudnn_handle (torch .cuda .current_stream ())
233
+
234
+ var_map = {
235
+ UIDs .Q_UID .value : q ,
236
+ UIDs .K_UID .value : k_cache ,
237
+ UIDs .V_UID .value : v_cache ,
238
+ UIDs .O_UID .value : out ,
239
+ }
240
+ if actual_seq_lens_q is not None :
241
+ var_map [UIDs .ACTUAL_SEQ_LENS_Q_UID .value ] = actual_seq_lens_q
242
+ if actual_seq_lens_kv is not None :
243
+ var_map [UIDs .ACTUAL_SEQ_LENS_KV_UID .value ] = actual_seq_lens_kv
244
+
245
+ if batch_offsets_q is not None :
246
+ var_map [UIDs .RAGGED_Q_UID .value ] = batch_offsets_q
247
+ if batch_offsets_o is not None :
248
+ var_map [UIDs .RAGGED_O_UID .value ] = batch_offsets_o
249
+
250
+ if block_tables is not None :
251
+ var_map [UIDs .BLOCK_TABLES_K_UID .value ] = block_tables
252
+ var_map [UIDs .BLOCK_TABLES_V_UID .value ] = block_tables
253
+
254
+ graph .execute (var_map , workspace = workspace_buffer , handle = handle_ )
255
+
256
+ return out
257
+
8
258
9
259
def cudnn_batch_decode_with_kv_cache (
10
260
q : torch .Tensor ,
@@ -37,7 +287,6 @@ def cudnn_batch_decode_with_kv_cache(
37
287
is_cuda_graph_compatible: Whether the decode operation is compatible with CUDA graph
38
288
batch_offsets: Optional batch offsets tensor of shape (batch_size,) on GPU
39
289
out: Optional pre-allocated output tensor
40
- lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None
41
290
batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
42
291
batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
43
292
batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
@@ -53,30 +302,51 @@ def cudnn_batch_decode_with_kv_cache(
53
302
"""
54
303
55
304
bs = q .shape [0 ]
56
- s_q = 1
57
305
h_qo = q .shape [1 ]
58
306
d_vo = v_cache .shape [3 ]
59
307
60
308
if out is None :
61
309
out = torch .empty (bs , h_qo , d_vo , device = q .device , dtype = q .dtype )
62
310
63
- actual_seq_lens_kv_gpu = actual_seq_lens_kv .to (q .device , non_blocking = True )
311
+ if not CUDNN_AVAILABLE :
312
+ actual_seq_lens_kv_gpu = actual_seq_lens_kv .to (q .device , non_blocking = True )
64
313
65
- run_func = get_cudnn_fmha_gen_module ().decode
66
- run_func (
67
- max_sequence_kv ,
68
- q ,
69
- k_cache ,
70
- v_cache ,
71
- scale ,
72
- workspace_buffer ,
73
- actual_seq_lens_kv ,
74
- actual_seq_lens_kv_gpu ,
75
- block_tables ,
76
- out ,
77
- batch_offsets_q ,
78
- batch_offsets_o ,
79
- is_cuda_graph_compatible ,
80
- )
314
+ run_func = get_cudnn_fmha_gen_module ().decode
315
+ run_func (
316
+ max_sequence_kv ,
317
+ q ,
318
+ k_cache ,
319
+ v_cache ,
320
+ scale ,
321
+ workspace_buffer ,
322
+ actual_seq_lens_kv ,
323
+ actual_seq_lens_kv_gpu ,
324
+ block_tables ,
325
+ out ,
326
+ batch_offsets_q ,
327
+ batch_offsets_o ,
328
+ is_cuda_graph_compatible ,
329
+ )
330
+ else :
331
+ actual_seq_lens_q = torch .ones (
332
+ (bs , 1 , 1 , 1 ), device = q .device , dtype = torch .int32
333
+ )
334
+ block_size = k_cache .shape [2 ]
335
+
336
+ _batch_decode_with_kv_cache (
337
+ q = q ,
338
+ k_cache = k_cache ,
339
+ v_cache = v_cache ,
340
+ scale = scale ,
341
+ workspace_buffer = workspace_buffer ,
342
+ max_sequence_kv = max_sequence_kv ,
343
+ actual_seq_lens_q = actual_seq_lens_q ,
344
+ actual_seq_lens_kv = actual_seq_lens_kv ,
345
+ block_tables = block_tables ,
346
+ batch_offsets_q = batch_offsets_q ,
347
+ batch_offsets_o = batch_offsets_o ,
348
+ block_size = block_size ,
349
+ out = out ,
350
+ )
81
351
82
352
return out
0 commit comments