
Commit e4c65a3

committed
FlashInfer full cuda graph support
Signed-off-by: fhl <[email protected]>
1 parent a56ec6f commit e4c65a3

File tree

2 files changed: +137 -25 lines changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 126 additions & 17 deletions
@@ -4,7 +4,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
 
 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -218,22 +218,43 @@ def __post_init__(self):
 
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True
 
     def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
         self.runner = runner
+        self.vllm_config = runner.vllm_config
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
-        self._decode_wrapper = None  # Wrapper for decode
+        self._decode_wrapper = None  # Wrapper for decode (general shape)
+        self.enable_cuda_graph = self.vllm_config.compilation_config.full_cuda_graph
+        if self.enable_cuda_graph:
+            # For full cudagraph capture, one `decode_wrapper` for each batch
+            # size is needed for FlashInfer.
+            self._decode_wrappers_cudagraph: dict[int, BatchDecodeWithPagedKVCacheWrapper] = {}
+            self._decode_cudagraph_max_bs = min(runner.max_num_reqs,
+                                                runner.cudagraph_batch_sizes[-1])
+
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
         self.global_hyperparameters: Optional[PerLayerParameters] = None
 
-        self.vllm_config = runner.vllm_config
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
+        # Preparing persistent buffers
+        self.paged_kv_indptr = torch.zeros(
+            self.runner.max_num_reqs + 1,
+            dtype=torch.int32,
+            device=self.runner.device)
+        self.paged_kv_indices = torch.zeros(
+            block_table.get_device_tensor().numel(),  # max num pages possible
+            dtype=torch.int32, device=self.runner.device)
+        self.paged_kv_last_page_len = torch.zeros(
+            self.runner.max_num_reqs,
+            dtype=torch.int32, device=self.runner.device)
+
     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
@@ -307,19 +328,47 @@ def _get_prefill_wrapper(self):
                 self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper
 
-    def _get_decode_wrapper(self):
-        if self._decode_wrapper is None:
+    def _get_decode_wrapper(self, batch_size: int, pure_decode: bool = False):
+        use_cudagraph = (self.enable_cuda_graph and pure_decode
+                         and batch_size <= self._decode_cudagraph_max_bs)
+
+        if use_cudagraph:
+            decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None)
+        else:
+            decode_wrapper = self._decode_wrapper
+
+        if decode_wrapper is None:
             num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                 self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
             use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                 num_qo_heads // num_kv_heads > 4)
-            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+
+            if use_cudagraph:
+                paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
+                paged_kv_indices = self.paged_kv_indices
+                paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size]
+            else:
+                paged_kv_indptr = None
+                paged_kv_indices = None
+                paged_kv_last_page_len = None
+            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 get_kv_cache_layout(),
+                use_cuda_graph=use_cudagraph,
+                paged_kv_indptr_buffer=paged_kv_indptr,
+                paged_kv_indices_buffer=paged_kv_indices,
+                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                 use_tensor_cores=use_tensor_cores)
-        return self._decode_wrapper
+
+            # save the decode wrapper
+            if use_cudagraph:
+                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
+            else:
+                self._decode_wrapper = decode_wrapper
+
+        return decode_wrapper
 
     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
@@ -395,11 +444,27 @@ def _plan(self, attn_metadata: FlashInferMetadata):
             )
 
         if self._num_decodes > 0:
-            attn_metadata.decode_wrapper = self._get_decode_wrapper()
+            pure_decode = self._num_prefills == 0
+            # Padding may be required for cudagraph replay.
+            if self.enable_cuda_graph and pure_decode and \
+                    self._num_decodes <= self._decode_cudagraph_max_bs:
+                num_input_tokens_decode = self.vllm_config.pad_for_cudagraph(
+                    self._num_decodes)
+            else:
+                num_input_tokens_decode = self._num_decodes
+
+            attn_metadata.decode_wrapper = self._get_decode_wrapper(
+                num_input_tokens_decode, pure_decode)
+            # TODO: Override FlashInfer's plan function to avoid some
+            # host-to-device copy overhead.
            attn_metadata.decode_wrapper.plan(
-                attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
-                attn_metadata.paged_kv_indices,
-                attn_metadata.paged_kv_last_page_len[:self._num_decodes],
+                # NOTE: Use the persistent buffers with the padded length
+                # instead of the sliced buffers in attn_metadata, to be
+                # compatible with FlashInfer's cudagraph requirement for
+                # the decode_wrapper.
+                self.paged_kv_indptr[:num_input_tokens_decode + 1],
+                self.paged_kv_indices,
+                self.paged_kv_last_page_len[:num_input_tokens_decode],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
                 attn_metadata.head_dim,
@@ -426,9 +491,16 @@ def build(self, common_prefix_len: int,
         device = self.runner.device
         qo_indptr = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
-        slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True).long()
+        block_table = self.block_table
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
 
         block_table_bounds = (seq_lens + page_size - 1) // page_size

@@ -462,24 +534,37 @@ def build(self, common_prefix_len: int,
                               device=block_table_tensor.device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
         paged_kv_indices = block_table_tensor[mask]
+        num_actual_pages = paged_kv_indices.size(0)
+        self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
+                                                       non_blocking=True)
+        self.paged_kv_indices[num_actual_pages:].fill_(-1)
 
         paged_kv_indptr = torch.cat([
             torch.zeros(1,
                         dtype=block_table_bounds.dtype,
                         device=block_table_bounds.device),
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])
+        self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
+                                                  non_blocking=True)
+        # make sure self.paged_kv_indptr is not decreasing
+        self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
 
         paged_kv_last_page_len = seq_lens % page_size
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
+        self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len,
+                                                     non_blocking=True)
+        self.paged_kv_last_page_len[num_reqs:].fill_(0)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
-            paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs],
+            paged_kv_indices=self.paged_kv_indices[:num_actual_pages],
+            paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs],
             num_qo_heads=self.runner.num_query_heads,
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
             head_dim=self.kv_cache_spec.head_size,
@@ -502,6 +587,30 @@ def build(self, common_prefix_len: int,
 
         return attn_metadata
 
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with FlashInfer.
+        """
+        m = common_attn_metadata
+        m.query_start_loc.copy_(torch.arange(m.num_actual_tokens + 1,
+                                             dtype=torch.int32,
+                                             device=self.runner.device),
+                                non_blocking=True)
+        assert m.num_reqs == m.num_actual_tokens, \
+            "FlashInfer only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        # Update state usually set in reorder_batch.
+        self._num_decodes = m.num_reqs
+        self._num_decode_tokens = m.num_actual_tokens
+        self._num_prefills = 0
+        self._num_prefill_tokens = 0
+        return self.build(0, m)
+
     def can_run_in_cudagraph(
             self, common_attn_metadata: CommonAttentionMetadata) -> bool:
         return common_attn_metadata.max_query_len == 1
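
The core idea of this file's change: a FlashInfer decode wrapper that will be replayed inside a full CUDA graph must be constructed with use_cuda_graph=True and bound to fixed paged-KV buffers, so the builder keeps one wrapper per capturable (padded) batch size and rewrites persistent paged_kv_* buffers in place before each replay. The following is a minimal standalone sketch of that caching pattern, not the vLLM code itself; DecodeWrapperCache and _StubWrapper are illustrative stand-ins for the real FlashInfer/vLLM classes, and the buffer sizes and device are assumptions.

# Illustrative sketch only: stand-in names, not the actual vLLM/FlashInfer API.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _StubWrapper:
    # Stand-in for BatchDecodeWithPagedKVCacheWrapper: a cudagraph-mode
    # wrapper is bound to fixed buffer slices at construction time.
    indptr: Optional[torch.Tensor]
    indices: Optional[torch.Tensor]
    last_page_len: Optional[torch.Tensor]
    use_cuda_graph: bool


class DecodeWrapperCache:
    def __init__(self, max_num_reqs: int, max_num_pages: int,
                 cudagraph_batch_sizes: list[int],
                 device: str = "cpu"):  # assumption: "cuda" in practice
        # Persistent buffers shared by all cudagraph-mode wrappers; they are
        # rewritten in place before every graph replay.
        self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                           dtype=torch.int32, device=device)
        self.paged_kv_indices = torch.zeros(max_num_pages,
                                            dtype=torch.int32, device=device)
        self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                                  dtype=torch.int32,
                                                  device=device)
        self._max_cudagraph_bs = min(max_num_reqs, cudagraph_batch_sizes[-1])
        self._cudagraph_wrappers: dict[int, _StubWrapper] = {}
        self._general_wrapper: Optional[_StubWrapper] = None

    def get(self, batch_size: int, pure_decode: bool) -> _StubWrapper:
        use_cudagraph = pure_decode and batch_size <= self._max_cudagraph_bs
        if not use_cudagraph:
            # Mixed prefill/decode or oversized batches fall back to a single
            # general-shape wrapper without fixed buffers.
            if self._general_wrapper is None:
                self._general_wrapper = _StubWrapper(None, None, None, False)
            return self._general_wrapper
        wrapper = self._cudagraph_wrappers.get(batch_size)
        if wrapper is None:
            # One wrapper per (padded) batch size, each bound to fixed slices
            # of the persistent buffers so a captured graph can be replayed.
            wrapper = _StubWrapper(self.paged_kv_indptr[:batch_size + 1],
                                   self.paged_kv_indices,
                                   self.paged_kv_last_page_len[:batch_size],
                                   True)
            self._cudagraph_wrappers[batch_size] = wrapper
        return wrapper

Usage sketch: cache.get(8, pure_decode=True) memoizes and returns the wrapper bound to the first 8 entries (plus one for indptr) of the persistent buffers, which is what a graph captured at batch size 8 keeps replaying.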

vllm/v1/worker/gpu_model_runner.py

Lines changed: 11 additions & 8 deletions
@@ -1948,14 +1948,17 @@ def _dummy_run(
         skip_attention_cuda_graphs = not attention_cuda_graphs \
             if self.full_cuda_graph else True
 
-        for kv_cache_group_id, kv_cache_group_spec in enumerate(
-                self.kv_cache_config.kv_cache_groups):
-
-            attn_metadata_i = self.attn_metadata_builders[
-                kv_cache_group_id].build_for_cudagraph_capture(
-                    common_attn_metadata)
-            for layer_name in kv_cache_group_spec.layer_names:
-                attn_metadata[layer_name] = attn_metadata_i
+        if not skip_attention_cuda_graphs:
+            for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                    self.kv_cache_config.kv_cache_groups):
+
+                attn_metadata_i = self.attn_metadata_builders[
+                    kv_cache_group_id].build_for_cudagraph_capture(
+                        common_attn_metadata)
+                for layer_name in kv_cache_group_spec.layer_names:
+                    attn_metadata[layer_name] = attn_metadata_i
+        else:
+            attn_metadata = None  # reset to None rather than an empty dict
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):

0 commit comments