@@ -89,23 +89,34 @@ def _vllm_layout_trans_kernel(
         tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
         tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
 
-def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
-                      k_cache, v_cache, max_seq_len, k_scale, v_scale,
-                      output_dtype, total_tokens):
+@torch.inference_mode()
+def vllm_layout_trans(b_query_lens_loc,
+                      b_seq_lens_loc,
+                      block_table,
+                      k_cache,
+                      v_cache,
+                      max_seq_len,
+                      k_scale,
+                      v_scale,
+                      output_dtype,
+                      total_tokens,
+                      k_values=None,
+                      v_values=None):
     H_KV = v_cache.shape[2]
     D = v_cache.shape[3]
     BLOCK_SIZE = v_cache.shape[1]
-
-    k_values = torch.empty(
-        (total_tokens, H_KV, D),
-        dtype=output_dtype,
-        device=k_cache.device,
-    )
-    v_values = torch.empty(
-        (total_tokens, H_KV, D),
-        dtype=output_dtype,
-        device=v_cache.device,
-    )
+    if k_values is None:
+        k_values = torch.empty(
+            (total_tokens, H_KV, D),
+            dtype=output_dtype,
+            device=k_cache.device,
+        )
+    if v_values is None:
+        v_values = torch.empty(
+            (total_tokens, H_KV, D),
+            dtype=output_dtype,
+            device=v_cache.device,
+        )
 
     grid = (block_table.shape[0],
             (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
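
[Reviewer note, not part of the diff] The new optional k_values/v_values parameters let a caller hand vllm_layout_trans preallocated destination buffers, so the per-call torch.empty only runs when nothing is passed in. A minimal sketch of the intended reuse pattern; the shapes, dtype, and device below are made up for illustration:

    import torch

    # Hypothetical sizes for illustration only.
    total_tokens, H_KV, D = 8192, 8, 128
    dtype, device = torch.bfloat16, "cuda"

    # Allocate once, e.g. when the attention metadata is built ...
    k_values = torch.empty((total_tokens, H_KV, D), dtype=dtype, device=device)
    v_values = torch.empty((total_tokens, H_KV, D), dtype=dtype, device=device)

    # ... then reuse the same buffers for every layer's call:
    # k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
    #                          k_cache, v_cache, max_seq_len, k_scale, v_scale,
    #                          dtype, total_tokens,
    #                          k_values=k_values, v_values=v_values)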
@@ -148,13 +159,14 @@ def flash_attn_varlen_func_impl(
     block_table: torch.Tensor,
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
-    total_tokens: int = 0,
+    total_tokens: int,
+    k_values: Optional[torch.Tensor] = None,
+    v_values: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    if total_tokens == 0:
-        total_tokens = int(cu_seqlens_k[-1].item())
     k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
                              k_cache, v_cache, max_seqlen_k, k_scale,
-                             v_scale, q.dtype, total_tokens)
+                             v_scale, q.dtype, total_tokens, k_values,
+                             v_values)
 
     output = aiter.flash_attn_varlen_func(
         q=q,
@@ -222,24 +234,27 @@ class AiterFlashAttentionMetadata:
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
-    cu_seq_lens: Optional[torch.Tensor]
 
     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
-    total_tokens: int
+    k_buffer: torch.Tensor
+    v_buffer: torch.Tensor
+    workspace_buffer: torch.Tensor
+    cu_seq_lens: torch.Tensor
 
 
 class AiterFlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
-    cudagraph_support = AttentionCGSupport.ALWAYS
+    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config
+        self.compilation_config = vllm_config.compilation_config
         self.device = device
 
         self.num_heads_q = self.model_config.get_num_attention_heads(
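
[Reviewer note, not part of the diff] As build() and forward() below show, k_buffer, v_buffer, workspace_buffer, and cu_seq_lens can all be None (decode-only batches, or metadata built under stream capture), so Optional annotations would describe the actual contract more precisely than the bare torch.Tensor types added here. A sketch of that suggestion, not the code in this PR:

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class BuffersMetadataSketch:
        # None for decode-only batches or during CUDA graph capture.
        k_buffer: Optional[torch.Tensor]
        v_buffer: Optional[torch.Tensor]
        workspace_buffer: Optional[torch.Tensor]
        cu_seq_lens: Optional[torch.Tensor]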
@@ -249,53 +264,68 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        # Sliding window size to be used with the AOT scheduler will be
-        # populated on first build() call.
-        self.aot_sliding_window: Optional[tuple[int, int]] = None
-        self.total_tokens: int = 0
-
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        self.total_tokens = self.model_config.max_model_len \
-            * self.vllm_config.scheduler_config.max_num_partial_prefills
-        res = self.build(common_prefix_len=0,
-                         common_attn_metadata=common_attn_metadata)
-        self.total_tokens = 0
-        return res
 
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
-
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = common_attn_metadata.max_seq_len
+
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
-        if max_query_len > 1:
-            # We pre-compute cumulative seq len needed for prefill attention
-            # here to avoid recomputing it for every layer
-            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=seq_lens.device)
-            torch.cumsum(seq_lens,
-                         dim=0,
-                         dtype=cu_seq_lens.dtype,
-                         out=cu_seq_lens[1:])
-            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
-        else:
-            cu_seq_lens = None
-            num_actual_kv_tokens = 0
-
-        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
-                     max_seq_len, causal):
-            return None
+        num_seqs = common_attn_metadata.num_reqs
+        max_seq_len = common_attn_metadata.max_seq_len
+        num_actual_kv_tokens = int(seq_lens.sum())
 
         use_cascade = common_prefix_len > 0
 
+        nbytes_per_qo_elem = torch.finfo(self.model_config.dtype).bits // 8
+        max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
+                              1) // _PARTITION_SIZE_ROCM
+        empty_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        k_buffer = None
+        v_buffer = None
+        workspace_buffer = None
+        cu_seq_lens = None
+        if max_query_len > 1:
+            required_memory = num_actual_kv_tokens * \
+                self.num_heads_kv * self.headdim * 2 * 2
+            if required_memory >= empty_gpu_memory:
+                raise ValueError(
+                    f"Not enough GPU memory to allocate k_buffer and v_buffer. "
+                    f"Required: {required_memory} bytes, "
+                    f"Available: {empty_gpu_memory} bytes, please reduce the "
+                    f"max_num_seqs or max_model_len.")
+            if not torch.cuda.graphs.is_current_stream_capturing():
+                k_buffer = torch.empty(
+                    (num_actual_kv_tokens, self.num_heads_kv, self.headdim),
+                    dtype=self.model_config.dtype,
+                    device=self.device,
+                )
+                v_buffer = torch.empty(
+                    (num_actual_kv_tokens, self.num_heads_kv, self.headdim),
+                    dtype=self.model_config.dtype,
+                    device=self.device,
+                )
+                cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=self.device)
+                torch.cumsum(seq_lens,
+                             dim=0,
+                             dtype=cu_seq_lens.dtype,
+                             out=cu_seq_lens[1:])
+
+        workspace_buffer = torch.empty(
+            (num_seqs * self.num_heads_q * max_num_partitions * self.headdim) *
+            nbytes_per_qo_elem + 2 *
+            (num_seqs * self.num_heads_q * max_num_partitions) * 4,
+            dtype=torch.uint8,
+            device=self.device,
+        )
+
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             num_actual_kv_tokens=num_actual_kv_tokens,
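
[Reviewer note, not part of the diff] The workspace sizing moved into build() is the same arithmetic forward() used before: one partial-output slab of num_seqs * num_heads_q * max_num_partitions * headdim elements in the query/output dtype, plus two 4-byte accumulators per (sequence, head, partition), presumably the split-K max and sum buffers of the paged-attention kernel. The required_memory guard likewise appears to be 2 bytes per element times 2 buffers (K and V) of transposed cache. A standalone re-derivation of the workspace size with made-up sizes; _PARTITION_SIZE_ROCM = 256 is an assumption of this sketch:

    # Illustrative only; every number here is an assumption.
    _PARTITION_SIZE_ROCM = 256
    num_seqs, num_heads_q, headdim = 64, 32, 128
    max_seq_len = 4096
    nbytes_per_qo_elem = 2  # bf16/fp16 model dtype

    max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
                          1) // _PARTITION_SIZE_ROCM  # 16 partitions
    workspace_bytes = (
        # partial outputs, one per (sequence, head, partition)
        num_seqs * num_heads_q * max_num_partitions * headdim *
        nbytes_per_qo_elem
        # plus two float32 scalars per (sequence, head, partition)
        + 2 * (num_seqs * num_heads_q * max_num_partitions) * 4)
    print(workspace_bytes)  # 8650752 bytes for these sizes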
@@ -305,10 +335,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
-            cu_seq_lens=cu_seq_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
-            total_tokens=self.total_tokens,
+            k_buffer=k_buffer,
+            v_buffer=v_buffer,
+            workspace_buffer=workspace_buffer,
+            cu_seq_lens=cu_seq_lens,
         )
         return attn_metadata
 
@@ -381,6 +413,7 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -410,6 +443,9 @@ def __init__(
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "FlashAttentionImpl")
+        self.sinks = sinks
+        if self.sinks is not None:
+            raise NotImplementedError("Sinks are not supported for ROCM AITER")
 
     def forward(
         self,
@@ -491,6 +527,17 @@ def forward(
             block_table = attn_metadata.block_table
 
             if max_seqlen_q > 1:
+                if attn_metadata.cu_seq_lens is None:
+                    cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
+                                              dtype=torch.int32,
+                                              device=query.device)
+                    torch.cumsum(seqused_k,
+                                 dim=0,
+                                 dtype=cu_seq_lens.dtype,
+                                 out=cu_seq_lens[1:])
+                else:
+                    cu_seq_lens = attn_metadata.cu_seq_lens
+
                 torch.ops.vllm.flash_attn_varlen_func(
                     query[:num_actual_tokens],
                     key_cache,
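
[Reviewer note, not part of the diff] This fallback rebuilds the exclusive prefix sum of the key lengths for the case where the metadata carries no cu_seq_lens (build() above leaves it as None when the stream is capturing or when max_query_len is 1). For reference, a toy example of the layout it produces:

    import torch

    # cu_seqlens_k layout: a leading zero followed by the running total of
    # per-sequence key lengths.
    seqused_k = torch.tensor([3, 5, 2], dtype=torch.int32)
    cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1, dtype=torch.int32)
    torch.cumsum(seqused_k, dim=0, dtype=cu_seq_lens.dtype,
                 out=cu_seq_lens[1:])
    print(cu_seq_lens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)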
@@ -503,25 +550,29 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                     window_size=self.sliding_window,
                     block_table=block_table,
-                    cu_seqlens_k=attn_metadata.cu_seq_lens,
+                    cu_seqlens_k=cu_seq_lens,
                     k_scale=layer._k_scale,
                     v_scale=layer._v_scale,
                     total_tokens=attn_metadata.num_actual_kv_tokens,
+                    k_values=attn_metadata.k_buffer,
+                    v_values=attn_metadata.v_buffer,
                 )
-
-            _, num_heads, head_size = query.shape
-            nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
-            num_seqs = seqused_k.shape[0]
-            max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
-                                  1) // _PARTITION_SIZE_ROCM
-
-            workspace_buffer = torch.empty(
-                (num_seqs * num_heads * max_num_partitions * head_size) *
-                nbytes_per_qo_elem + 2 *
-                (num_seqs * num_heads * max_num_partitions) * 4,
-                dtype=torch.uint8,
-                device=output.device,
-            )
+            if attn_metadata.workspace_buffer is None:
+                _, num_heads, head_size = query.shape
+                nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
+                num_seqs = seqused_k.shape[0]
+                max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
+                                      1) // _PARTITION_SIZE_ROCM
+
+                workspace_buffer = torch.empty(
+                    (num_seqs * num_heads * max_num_partitions * head_size) *
+                    nbytes_per_qo_elem + 2 *
+                    (num_seqs * num_heads * max_num_partitions) * 4,
+                    dtype=torch.uint8,
+                    device=output.device,
+                )
+            else:
+                workspace_buffer = attn_metadata.workspace_buffer
 
             torch.ops.aiter.paged_attention_v1(
                 output[:num_actual_tokens],
@@ -543,6 +594,12 @@ def forward(
                 None,
                 _PARTITION_SIZE_ROCM,
             )
+            if workspace_buffer is not None:
+                workspace_buffer.zero_()
+            if attn_metadata.k_buffer is not None:
+                attn_metadata.k_buffer.zero_()
+            if attn_metadata.v_buffer is not None:
+                attn_metadata.v_buffer.zero_()
             return output
         else:
             raise NotImplementedError(