@@ -4,7 +4,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Optional

 import torch

 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -218,22 +218,43 @@ def __post_init__(self):


 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True

     def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
         self.runner = runner
+        self.vllm_config = runner.vllm_config
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
-        self._decode_wrapper = None  # Wrapper for decode
+        self._decode_wrapper = None  # Wrapper for decode (general shape)
+        self.enable_cuda_graph = self.vllm_config.compilation_config.full_cuda_graph
+        if self.enable_cuda_graph:
+            # For full cudagraph capture, one `decode_wrapper` for each batch
+            # size is needed for FlashInfer.
+            self._decode_wrappers_cudagraph: dict[int, BatchDecodeWithPagedKVCacheWrapper] = {}
+            self._decode_cudagraph_max_bs = min(runner.max_num_reqs,
+                                                runner.cudagraph_batch_sizes[-1])
+
         self._cascade_wrapper = None  # Wrapper for cascade attention

         # Global hyperparameters shared by all attention layers
         self.global_hyperparameters: Optional[PerLayerParameters] = None

-        self.vllm_config = runner.vllm_config
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table

+        # Preparing persistent buffers
+        self.paged_kv_indptr = torch.zeros(
+            self.runner.max_num_reqs + 1,
+            dtype=torch.int32,
+            device=self.runner.device)
+        self.paged_kv_indices = torch.zeros(
+            block_table.get_device_tensor().numel(),  # max num pages possible
+            dtype=torch.int32, device=self.runner.device)
+        self.paged_kv_last_page_len = torch.zeros(
+            self.runner.max_num_reqs,
+            dtype=torch.int32, device=self.runner.device)
+
     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
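The persistent buffers above are allocated once for the worst case, so the same device memory can back every CUDA graph replay. A rough sizing sketch with illustrative numbers (none of these values come from the PR; the real ones are taken from the runner and block table):

# Rough worst-case sizing of the persistent paged-KV buffers
# (illustrative, assumed numbers only).
max_num_reqs = 256                      # assumed scheduler limit
max_model_len = 8192                    # assumed longest sequence
page_size = 16                          # assumed KV-cache block size

max_blocks_per_req = max_model_len // page_size      # 512 pages per request
max_num_pages = max_num_reqs * max_blocks_per_req    # 131072 int32 entries

# paged_kv_indptr holds one entry per request plus a leading zero,
# paged_kv_indices one entry per possible page, last_page_len one per request.
buffer_elems = (max_num_reqs + 1) + max_num_pages + max_num_reqs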
@@ -307,19 +328,47 @@ def _get_prefill_wrapper(self):
                 self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper

-    def _get_decode_wrapper(self):
-        if self._decode_wrapper is None:
+    def _get_decode_wrapper(self, batch_size: int, pure_decode: bool = False):
+        use_cudagraph = (self.enable_cuda_graph and pure_decode
+                         and batch_size <= self._decode_cudagraph_max_bs)
+
+        if use_cudagraph:
+            decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None)
+        else:
+            decode_wrapper = self._decode_wrapper
+
+        if decode_wrapper is None:
             num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                 self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
             use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                 num_qo_heads // num_kv_heads > 4)
-            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+
+            if use_cudagraph:
+                paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
+                paged_kv_indices = self.paged_kv_indices
+                paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size]
+            else:
+                paged_kv_indptr = None
+                paged_kv_indices = None
+                paged_kv_last_page_len = None
+            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 get_kv_cache_layout(),
+                use_cuda_graph=use_cudagraph,
+                paged_kv_indptr_buffer=paged_kv_indptr,
+                paged_kv_indices_buffer=paged_kv_indices,
+                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                 use_tensor_cores=use_tensor_cores)
-        return self._decode_wrapper
+
+            # Save the decode wrapper.
+            if use_cudagraph:
+                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
+            else:
+                self._decode_wrapper = decode_wrapper
+
+        return decode_wrapper

     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
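The `use_tensor_cores` decision above (ignoring the `VLLM_FLASHINFER_FORCE_TENSOR_CORES` override) is plain integer arithmetic on the GQA group size. A quick worked example with made-up head counts:

# Worked example of the tensor-core heuristic (illustrative head counts).
# FlashInfer's tensor-core decode path pays off when each KV head serves a
# large group of query heads.
num_qo_heads, num_kv_heads = 32, 4      # 8 query heads per KV head
use_tensor_cores = (num_qo_heads // num_kv_heads) > 4   # 8 > 4 -> True

num_qo_heads, num_kv_heads = 32, 8      # 4 query heads per KV head
use_tensor_cores = (num_qo_heads // num_kv_heads) > 4   # 4 > 4 -> False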
@@ -395,11 +444,27 @@ def _plan(self, attn_metadata: FlashInferMetadata):
             )

         if self._num_decodes > 0:
-            attn_metadata.decode_wrapper = self._get_decode_wrapper()
+            pure_decode = self._num_prefills == 0
+            # Possible required padding for cudagraph replay.
+            if self.enable_cuda_graph and pure_decode and \
+                    self._num_decodes <= self._decode_cudagraph_max_bs:
+                num_input_tokens_decode = self.vllm_config.pad_for_cudagraph(
+                    self._num_decodes)
+            else:
+                num_input_tokens_decode = self._num_decodes
+
+            attn_metadata.decode_wrapper = self._get_decode_wrapper(
+                num_input_tokens_decode, pure_decode)
+            # TODO: Override flashinfer's plan function to avoid some
+            # host-to-device copy overhead.
             attn_metadata.decode_wrapper.plan(
-                attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
-                attn_metadata.paged_kv_indices,
-                attn_metadata.paged_kv_last_page_len[:self._num_decodes],
+                # NOTE: Use the persistent buffers with the padded length,
+                # instead of the exact-length buffers in attn_metadata.
+                # This is required for compatibility with FlashInfer's
+                # decode_wrapper cudagraph requirement.
+                self.paged_kv_indptr[:num_input_tokens_decode + 1],
+                self.paged_kv_indices,
+                self.paged_kv_last_page_len[:num_input_tokens_decode],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
                 attn_metadata.head_dim,
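A minimal sketch of the padding step, assuming `pad_for_cudagraph` rounds a pure-decode batch up to the smallest captured graph size that fits (the capture sizes below are made up, and this helper is only an illustration, not the vLLM implementation):

# Minimal sketch of cudagraph padding under the assumption above.
import bisect

def pad_for_cudagraph(batch_size: int, capture_sizes: list[int]) -> int:
    # capture_sizes is assumed sorted ascending, e.g. [1, 2, 4, 8].
    return capture_sizes[bisect.bisect_left(capture_sizes, batch_size)]

# Example: 3 decode requests replay the batch-size-4 graph; the padded slot
# is planned against the persistent buffers filled in build().
assert pad_for_cudagraph(3, [1, 2, 4, 8]) == 4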
@@ -426,9 +491,16 @@ def build(self, common_prefix_len: int,
         device = self.runner.device
         qo_indptr = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
-        slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True).long()
+        block_table = self.block_table
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

         block_table_bounds = (seq_lens + page_size - 1) // page_size

@@ -462,24 +534,37 @@ def build(self, common_prefix_len: int,
                              device=block_table_tensor.device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
         paged_kv_indices = block_table_tensor[mask]
+        num_actual_pages = paged_kv_indices.size(0)
+        self.paged_kv_indices[:num_actual_pages].copy_(
+            paged_kv_indices, non_blocking=True)
+        self.paged_kv_indices[num_actual_pages:].fill_(-1)

         paged_kv_indptr = torch.cat([
             torch.zeros(1,
                         dtype=block_table_bounds.dtype,
                         device=block_table_bounds.device),
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])
+        self.paged_kv_indptr[:1 + num_reqs].copy_(
+            paged_kv_indptr, non_blocking=True)
+        # Make sure self.paged_kv_indptr is not decreasing.
+        self.paged_kv_indptr[1 + num_reqs:].fill_(
+            paged_kv_indptr[-1])

         paged_kv_last_page_len = seq_lens % page_size
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
+        self.paged_kv_last_page_len[:num_reqs].copy_(
+            paged_kv_last_page_len, non_blocking=True)
+        self.paged_kv_last_page_len[num_reqs:].fill_(
+            0)

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
-            paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs],
+            paged_kv_indices=self.paged_kv_indices[:num_actual_pages],
+            paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs],
             num_qo_heads=self.runner.num_query_heads,
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
             head_dim=self.kv_cache_spec.head_size,
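For reference, a self-contained sketch of the paged-KV layout these buffers hold, mirroring the computation above on a tiny hypothetical block table (the page IDs and sequence lengths are made up):

# Tiny worked example of the paged-KV layout (hypothetical inputs).
import torch

page_size = 4
seq_lens = torch.tensor([5, 3])                   # two decode requests
block_table = torch.tensor([[10, 11, 0],          # pages assigned to request 0
                            [20, 0, 0]])          # pages assigned to request 1

block_table_bounds = (seq_lens + page_size - 1) // page_size   # tensor([2, 1])
mask = (torch.arange(block_table.size(1)).unsqueeze(0)
        < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]                            # [10, 11, 20]
paged_kv_indptr = torch.cat([torch.zeros(1, dtype=torch.int64),
                             block_table_bounds.cumsum(0)])     # [0, 2, 3]
last = seq_lens % page_size
paged_kv_last_page_len = torch.where(last == 0, page_size, last)  # [1, 3]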
@@ -502,6 +587,30 @@ def build(self, common_prefix_len: int,

         return attn_metadata

+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with FlashInfer.
+        """
+        m = common_attn_metadata
+        m.query_start_loc.copy_(torch.arange(m.num_actual_tokens + 1,
+                                             dtype=torch.int32,
+                                             device=self.runner.device),
+                                non_blocking=True)
+        assert m.num_reqs == m.num_actual_tokens, \
+            "FlashInfer only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        # Update state usually set in reorder_batch.
+        self._num_decodes = m.num_reqs
+        self._num_decode_tokens = m.num_actual_tokens
+        self._num_prefills = 0
+        self._num_prefill_tokens = 0
+        return self.build(0, m)
+
     def can_run_in_cudagraph(
             self, common_attn_metadata: CommonAttentionMetadata) -> bool:
         return common_attn_metadata.max_query_len == 1
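A small illustration of why the capture path can synthesize `query_start_loc` with an `arange`: during pure decode every request contributes exactly one query token, so the cumulative offsets are simply 0..num_reqs (the batch size below is an assumed capture size):

# Decode-only capture metadata, for an assumed capture batch of 4 requests.
import torch

num_reqs = 4
num_actual_tokens = num_reqs                       # one token per request
query_start_loc = torch.arange(num_actual_tokens + 1,
                               dtype=torch.int32)  # tensor([0, 1, 2, 3, 4])
max_query_len = 1                                  # can_run_in_cudagraph -> True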