Commit 6302a7d

full_cudagraph support for FA2
Signed-off-by: fhl <[email protected]>
1 parent: ee9a153

File tree: 10 files changed, +303 −36 lines

vllm/compilation/backends.py

Lines changed: 13 additions & 3 deletions
@@ -563,9 +563,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
         self._called = True
 
-        if not self.compilation_config.use_cudagraph or \
-                not self.compilation_config.cudagraph_copy_inputs:
-            return self.split_gm
+
         # if we need to copy input buffers for cudagraph
         from torch._guards import detect_fake_mode
@@ -585,6 +583,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             any(is_symbolic(d) for d in x.size())
         ]
 
+        if self.compilation_config.full_cuda_graph:
+            assert self.compilation_config.use_cudagraph, \
+                "full_cuda_graph mode requires use_cudagraph to be True"
+            fullgraph_wrapper = resolve_obj_by_qualname(
+                current_platform.get_fullgraph_wrapper_cls())
+            self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config,
+                self.graph_pool, self.sym_tensor_indices)
+
+        if not self.compilation_config.use_cudagraph or \
+                not self.compilation_config.cudagraph_copy_inputs:
+            return self.split_gm
+
         # compiler managed cudagraph input buffers
         # we assume the first run with symbolic shapes
         # has the maximum size among all the tensors

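The hunk above resolves the platform's wrapper class from a dotted string via resolve_obj_by_qualname(current_platform.get_fullgraph_wrapper_cls()) and rewraps split_gm with it. As a minimal, standalone sketch of that qualname-resolution pattern (standard library only; the helper name load_qualname below is illustrative, not vLLM's API):

    import importlib
    from typing import Any

    def load_qualname(qualname: str) -> Any:
        # "pkg.module.ClassName" -> the ClassName object, imported lazily.
        module_name, _, attr = qualname.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, attr)

    # e.g. load_qualname("collections.OrderedDict") returns the OrderedDict
    # class, just as the backend turns the platform-provided string into a
    # wrapper class at compile time.
    wrapper_cls = load_qualname("collections.OrderedDict")
    assert wrapper_cls.__name__ == "OrderedDict"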
vllm/compilation/base_piecewise_backend.py

Lines changed: 43 additions & 0 deletions
@@ -70,3 +70,46 @@ def __call__(self, *args) -> Any:
         or a replayed static graph.
         """
         raise NotImplementedError
+
+
+class AbstractFullgraphWrapper(Protocol):
+    """
+    FullgraphWrapper interface that allows platforms to wrap the piecewise graph
+    to be viewed or captured as a full graph.
+    """
+
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, sym_shape_indices: list[int], **kwargs):
+        """
+        Initializes the FullgraphWrapper class with compilation and
+        execution-related configurations.
+
+        Args:
+            graph (fx.GraphModule): The graph represented in fx.
+            vllm_config (VllmConfig): Global configuration for vLLM.
+            graph_pool (Any):
+                Graph memory pool handle, e.g.,
+                `torch.cuda.graph_pool_handle()`.
+            sym_shape_indices (list[int]):
+                Indices of symbolic shape.
+
+        Keyword Args:
+            kwargs: Additional keyword arguments reserved for future
+                extensions or custom platforms.
+
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args) -> Any:
+        """
+        Executes the wrapped graph for given input args.
+
+        Args:
+            *args: Variable length input arguments to be passed into the
+                graph. The symbolic shape is expected to be in position
+                `sym_shape_indices[0]`.
+
+        Returns:
+            Any: Output of the executed wrapped graph.
+        """
+        raise NotImplementedError

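To make the protocol concrete, here is a minimal sketch of a class that satisfies AbstractFullgraphWrapper by simply running the wrapped callable eagerly. The class name and the no-op behavior are hypothetical; a real implementation would capture and replay CUDA graphs the way the CUDA wrapper in the next file does.

    from typing import Any

    class EagerFullgraphWrapper:
        """Hypothetical no-op wrapper: matches the protocol's signature but
        never captures a graph, always running the wrapped callable eagerly."""

        def __init__(self, graph, vllm_config, graph_pool: Any,
                     sym_shape_indices: list[int], **kwargs):
            self.graph = graph
            self.sym_shape_indices = sym_shape_indices

        def __call__(self, *args) -> Any:
            # A real wrapper would read the runtime batch size from
            # args[self.sym_shape_indices[0]] and key captured graphs on it.
            return self.graph(*args)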
vllm/compilation/cuda_piecewise_backend.py

Lines changed: 145 additions & 5 deletions
@@ -96,6 +96,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 runtime_shape=shape,
                 need_to_compile=shape in self.compile_sizes,
                 use_cudagraph=shape in self.cudagraph_capture_sizes,
+                usage_type="piecewise(general)",  # for logging only
             )
 
     def check_for_ending_compilation(self):
@@ -139,27 +140,32 @@ def __call__(self, *args) -> Any:
                 self.check_for_ending_compilation()
 
         # Skip CUDA graphs if this entry doesn't use them OR
-        # if we're supposed to skip them globally
-        skip_cuda_graphs = get_forward_context().skip_cuda_graphs
-        if not entry.use_cudagraph or skip_cuda_graphs:
+        # if we're supposed to treat the piecewise graphs as a whole,
+        # which implies forward_context.skip_attention_cuda_graphs is False.
+        # In the latter case, we rely on a wrapper class to capture
+        # the full cudagraph outside the fx graph.
+        skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs
+        if not entry.use_cudagraph or not skip_attention_cuda_graphs:
             return entry.runnable(*args)
 
         if entry.cudagraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
-                        "Warming up %s/%s for shape %s",
+                        "Warming up %s/%s of %s usage for shape %s",
                         entry.num_finished_warmup,
                         self.compilation_config.cudagraph_num_of_warmups,
+                        entry.usage_type,
                         runtime_shape)
                return entry.runnable(*args)
 
            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
-                logger.debug("Capturing a cudagraph for shape %s",
+                logger.debug("Capturing a cudagraph of %s usage for shape %s",
+                             entry.usage_type,
                              runtime_shape)
 
            input_addresses = [
@@ -216,3 +222,137 @@ def __call__(self, *args) -> Any:
 
         entry.cudagraph.replay()
         return entry.output
+
+
+class FullCudagraphWrapper:
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, sym_shape_indices: list[int],
+                 ):
+        self.graph = graph
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
+        self.graph_pool = graph_pool
+        self.sym_shape_indices = sym_shape_indices
+
+        self.separate_attention_routine = vllm_config.compilation_config.separate_attention_routine
+
+        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
+
+        self.first_run_finished = False
+
+        self.cudagraph_capture_sizes: set[int] = set(
+            self.compilation_config.cudagraph_capture_sizes
+        ) if self.compilation_config.use_cudagraph else set()
+
+        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
+        self.concrete_size_entries_decode: dict[int, ConcreteSizeEntry] = {}
+
+        for shape in self.cudagraph_capture_sizes:
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(
+                runtime_shape=shape,
+                need_to_compile=False,
+                use_cudagraph=True,
+                usage_type="general",
+            )
+            if self.separate_attention_routine:
+                self.concrete_size_entries_decode[shape] = ConcreteSizeEntry(
+                    runtime_shape=shape,
+                    need_to_compile=False,
+                    use_cudagraph=True,
+                    usage_type="decode",
+                )
+
+    def __call__(self, *args) -> Any:
+        if not self.first_run_finished:
+            self.first_run_finished = True
+            return self.graph(*args)
+        list_args = list(args)
+        runtime_shape = list_args[self.sym_shape_indices[0]].shape[0]
+        forward_context = get_forward_context()
+
+        if forward_context.skip_attention_cuda_graphs:
+            # fall back to the piecewise cudagraph backend, which is
+            # responsible for capturing and running the piecewise cudagraphs.
+            return self.graph(*args)
+
+        # if not skipped, the fx graph and its sub-graphs are only supposed to
+        # eagerly run the compiled graphs, which should be cudagraph capturable
+        # as a whole.
+
+        concrete_size_entries = self.concrete_size_entries  # default to general usage
+        if self.separate_attention_routine and forward_context.is_pure_decoding:
+            concrete_size_entries = self.concrete_size_entries_decode
+
+        if runtime_shape not in concrete_size_entries:
+            # we don't need to do anything for this shape.
+            return self.graph(*args)
+
+        entry = concrete_size_entries[runtime_shape]
+
+        if entry.runnable is None:
+            entry.runnable = self.graph
+
+        if not entry.use_cudagraph:
+            return entry.runnable(*args)
+
+        if entry.cudagraph is None:
+            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
+                entry.num_finished_warmup += 1
+                logger.debug(
+                    "Warming up %s/%s of %s usage for shape %s",
+                    entry.num_finished_warmup,
+                    self.compilation_config.cudagraph_num_of_warmups,
+                    entry.usage_type,
+                    runtime_shape)
+                return entry.runnable(*args)
+
+            # Since we capture cudagraph for many different shapes and
+            # capturing is fast, we don't need to log it for every shape.
+            # We only log it in the debug mode.
+            logger.debug("Capturing a cudagraph of %s usage for shape %s",
+                         entry.usage_type,
+                         runtime_shape)
+
+            input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            entry.input_addresses = input_addresses
+            cudagraph = torch.cuda.CUDAGraph()
+
+            with ExitStack() as stack:
+                # mind-exploding: carefully manage the reference and memory.
+                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+                    # `output` is managed by pytorch's cudagraph pool
+                    output = entry.runnable(*args)
+                    # by converting it to weak ref,
+                    # the original `output` will immediately be released
+                    # to save memory.
+                    output = weak_ref_tensors(output)
+
+            # here we always use weak ref for the output
+            # to save memory
+            entry.output = weak_ref_tensors(output)
+            entry.cudagraph = cudagraph
+
+            compilation_counter.num_cudagraph_captured += 1
+
+            # important: we need to return the output, rather than
+            # the weak ref of the output, so that pytorch can correctly
+            # manage the memory during cuda graph capture
+            return output
+
+        if self.is_debugging_mode:
+            # check if the input addresses are the same
+            new_input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            assert new_input_addresses == entry.input_addresses, (
+                "Input addresses for cudagraphs are different during replay."
+                f" Expected {entry.input_addresses}, got {new_input_addresses}"
+            )
+
+        entry.cudagraph.replay()
+        return entry.output

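The capture-then-replay mechanics that FullCudagraphWrapper relies on are standard PyTorch CUDA graph usage: capture once per (shape, routine) into a shared memory pool, keep the input/output buffers at stable addresses, and replay afterwards. A minimal standalone sketch of that pattern (a toy matmul stands in for the model; requires a CUDA device):

    import torch

    def capture_and_replay_demo() -> torch.Tensor:
        # Static buffers: their addresses must not change between capture and
        # replay, which is why the wrapper asserts on input data_ptr()s above.
        static_x = torch.randn(8, 64, device="cuda")
        weight = torch.randn(64, 64, device="cuda")

        # Warm up before capture so lazy initialization happens outside the graph.
        for _ in range(3):
            _ = static_x @ weight
        torch.cuda.synchronize()

        pool = torch.cuda.graph_pool_handle()   # plays the role of self.graph_pool
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            static_out = static_x @ weight      # output memory owned by the pool

        # Replay with fresh data: overwrite the static input in place, then replay.
        static_x.copy_(torch.randn(8, 64, device="cuda"))
        graph.replay()
        torch.cuda.synchronize()
        return static_out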
vllm/config.py

Lines changed: 16 additions & 6 deletions
@@ -3981,6 +3981,14 @@ class CompilationConfig:
     splitting certain operations such as attention into subgraphs. Thus this
     flag cannot be used together with splitting_ops. This may provide
     performance benefits for smaller models."""
+    separate_attention_routine: bool = False
+    """
+    Enable a distinct attention routine within the attention backend for full
+    cuda graph capturing. Some attention backends, such as FlashMLA,
+    FlashInfer, and FA2, implement different branches for mixed prefill-decode
+    and pure decode cases. This flag allows the cudagraph to be captured
+    separately for each branch.
+    """
 
     pass_config: PassConfig = field(default_factory=PassConfig)
     """Custom inductor passes, see PassConfig for more details"""
@@ -4179,13 +4187,15 @@ def init_with_cudagraph_sizes(self,
 
     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called
-        if self.splitting_ops and self.full_cuda_graph:
-            raise ValueError("full_cuda_graph cannot be used together with "
-                             "splitting_ops, as Full CUDA graph will override "
-                             f"the splitting_ops: {self.splitting_ops}")
-
+        # NOTE: When full_cuda_graph is True, instead of setting an empty list
+        # and capturing the full cudagraph inside the flattened fx graph, we
+        # keep the piecewise fx graph structure but capture the full cudagraph
+        # outside the fx graph. This reduces some cpu overhead when the runtime
+        # batch_size is not cudagraph captured.
+        if self.separate_attention_routine:
+            assert self.full_cuda_graph, "separate_attention_routine requires full_cuda_graph to be True"
         if not self.splitting_ops:
-            self.splitting_ops = [] if self.full_cuda_graph else [
+            self.splitting_ops = [
                 "vllm.unified_attention",
                 "vllm.unified_attention_with_output",
             ]

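As a hedged usage sketch of the two flags touched above (the field names come from this diff; CompilationConfig construction and how the config is wired into an engine may differ between vLLM versions):

    from vllm.config import CompilationConfig

    # Assumed construction path: CompilationConfig is a dataclass-style config,
    # so keyword arguments for its fields should work.
    cfg = CompilationConfig(
        full_cuda_graph=True,             # capture the whole forward as one graph
        separate_attention_routine=True,  # distinct captures for decode vs. mixed batches
    )

    # Per the diff, set_splitting_ops_for_v1() now keeps the piecewise structure:
    # with no explicit splitting_ops, it fills in the attention splitting ops even
    # when full_cuda_graph is enabled (the wrapper captures outside the fx graph).
    cfg.set_splitting_ops_for_v1()
    print(cfg.splitting_ops)  # expected: the two vllm.unified_attention* ops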
vllm/forward_context.py

Lines changed: 9 additions & 3 deletions
@@ -94,7 +94,11 @@ class ForwardContext:
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
-    skip_cuda_graphs: bool = False
+    # Determines whether to use a full cudagraph for attention or piecewise
+    # cudagraphs that skip the attention part. Defaults to True, i.e. use
+    # piecewise cudagraphs.
+    skip_attention_cuda_graphs: bool = True
+    is_pure_decoding: bool = False
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -115,7 +119,8 @@ def set_forward_context(
     virtual_engine: int = 0,
     num_tokens: Optional[int] = None,
     num_tokens_across_dp: Optional[torch.Tensor] = None,
-    skip_cuda_graphs: bool = False,
+    skip_attention_cuda_graphs: bool = True,
+    is_pure_decoding: bool = False,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -140,7 +145,8 @@ def set_forward_context(
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
-        skip_cuda_graphs=skip_cuda_graphs,
+        skip_attention_cuda_graphs=skip_attention_cuda_graphs,
+        is_pure_decoding=is_pure_decoding,
     )
 
     try:

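The two new context fields drive a three-way dispatch at runtime. The sketch below restates that decision logic in isolation (a plain illustrative function, not vLLM code), mirroring how CUDAPiecewiseBackend and FullCudagraphWrapper read the forward context above:

    def choose_cudagraph_route(skip_attention_cuda_graphs: bool,
                               is_pure_decoding: bool,
                               separate_attention_routine: bool) -> str:
        # Mirrors the checks in CUDAPiecewiseBackend.__call__ and
        # FullCudagraphWrapper.__call__ (names reused for clarity only).
        if skip_attention_cuda_graphs:
            return "piecewise cudagraphs (attention excluded)"
        if separate_attention_routine and is_pure_decoding:
            return "full cudagraph, decode-only routine"
        return "full cudagraph, general (mixed prefill-decode) routine"

    assert choose_cudagraph_route(True, False, True).startswith("piecewise")
    assert choose_cudagraph_route(False, True, True).endswith("decode-only routine")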
vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -359,6 +359,10 @@ def use_custom_allreduce(cls) -> bool:
     @classmethod
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
+
+    @classmethod
+    def get_fullgraph_wrapper_cls(cls) -> str:
+        return "vllm.compilation.cuda_piecewise_backend.FullCudagraphWrapper"  # noqa
 
     @classmethod
     def stateless_init_device_torch_dist_pg(

vllm/platforms/interface.py

Lines changed: 7 additions & 0 deletions
@@ -524,6 +524,13 @@ def get_piecewise_backend_cls(cls) -> str:
         Get piecewise backend class for piecewise graph.
         """
         return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend"  # noqa
+
+    @classmethod
+    def get_fullgraph_wrapper_cls(cls) -> str:
+        """
+        Get fullgraph wrapper class for fullgraph static graph.
+        """
+        return "vllm.compilation.base_piecewise_backend.AbstractFullgraphWrapper"  # noqa
 
     @classmethod
     def stateless_init_device_torch_dist_pg(

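An out-of-tree platform would override this hook the same way the CUDA platform does above, returning the qualified name of its own wrapper class. A brief sketch with hypothetical names (the plugin module and wrapper do not exist in vLLM):

    from vllm.platforms.interface import Platform

    class MyPluginPlatform(Platform):
        """Hypothetical plugin platform; only the new hook is shown."""

        @classmethod
        def get_fullgraph_wrapper_cls(cls) -> str:
            # Must be resolvable by resolve_obj_by_qualname at compile time.
            return "my_plugin.compilation.MyFullgraphWrapper"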
vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 6 deletions
@@ -158,9 +158,6 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
 
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
-        if self.use_full_cuda_graph and not self.aot_schedule:
-            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
-                             "which requires FlashAttention 3.")
         self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device=self.runner.device)
@@ -299,8 +296,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             max_seq_len=max_seq_len,
             causal=True)
 
-        if self.use_full_cuda_graph:
-            assert scheduler_metadata is not None
+        if scheduler_metadata is not None:
             n = scheduler_metadata.shape[0]
             self.scheduler_metadata[:n].copy_(scheduler_metadata,
                                               non_blocking=True)
@@ -332,7 +328,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
 
     def can_run_in_cudagraph(
             self, common_attn_metadata: CommonAttentionMetadata) -> bool:
-        # Full CUDA Graph always supported (FA2 support checked separately)
+        # Full CUDA Graph always supported (FA2 and FA3 support)
         return True
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:

vllm/v1/attention/backends/flashinfer.py

Lines changed: 5 additions & 1 deletion
@@ -501,7 +501,11 @@ def build(self, common_prefix_len: int,
         self._plan(attn_metadata)
 
         return attn_metadata
-
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
