@@ -96,6 +96,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 runtime_shape=shape,
                 need_to_compile=shape in self.compile_sizes,
                 use_cudagraph=shape in self.cudagraph_capture_sizes,
+                usage_type="piecewise(general)",  # for logging only
             )
 
     def check_for_ending_compilation(self):
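Note on `usage_type`: it is purely informational and only interpolated into the debug logs touched below. As a sketch of where the field lives, modeled on the existing `ConcreteSizeEntry` dataclass in this file (field order and defaults here are assumptions, not part of the diff):

```python
import dataclasses
from typing import Any, Callable, Optional

import torch


@dataclasses.dataclass
class ConcreteSizeEntry:
    runtime_shape: int
    need_to_compile: bool  # the shape is in compile_sizes
    use_cudagraph: bool    # the shape is in cudagraph_capture_sizes
    # New: human-readable tag ("piecewise(general)", "general", "decode"),
    # used only when formatting the warmup/capture log messages below.
    usage_type: str = "general"

    compiled: bool = False
    runnable: Optional[Callable] = None
    num_finished_warmup: int = 0
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None
    # recorded at capture time so replay can verify stable input addresses
    input_addresses: Optional[list[int]] = None
```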
@@ -139,27 +140,32 @@ def __call__(self, *args) -> Any:
                 self.check_for_ending_compilation()
 
         # Skip CUDA graphs if this entry doesn't use them OR
-        # if we're supposed to skip them globally
-        skip_cuda_graphs = get_forward_context().skip_cuda_graphs
-        if not entry.use_cudagraph or skip_cuda_graphs:
+        # if we're supposed to treat the piecewise graphs as a whole,
+        # which implies forward_context.skip_attention_cuda_graphs is False.
+        # In the latter case, we rely on a wrapper class to capture
+        # the full cudagraph outside the fx graph.
+        skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs
+        if not entry.use_cudagraph or not skip_attention_cuda_graphs:
             return entry.runnable(*args)
 
         if entry.cudagraph is None:
             if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                 entry.num_finished_warmup += 1
                 if self.is_first_graph:
                     logger.debug(
-                        "Warming up %s/%s for shape %s",
+                        "Warming up %s/%s of %s usage for shape %s",
                         entry.num_finished_warmup,
                         self.compilation_config.cudagraph_num_of_warmups,
+                        entry.usage_type,
                         runtime_shape)
                 return entry.runnable(*args)
 
             if self.is_first_graph:
                 # Since we capture cudagraph for many different shapes and
                 # capturing is fast, we don't need to log it for every shape.
                 # We only log it in the debug mode.
-                logger.debug("Capturing a cudagraph for shape %s",
+                logger.debug("Capturing a cudagraph of %s usage for shape %s",
+                             entry.usage_type,
                              runtime_shape)
 
             input_addresses = [
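The new guard reads as a double negative, so here is the same dispatch condition pulled out into a standalone helper (a sketch for review; `should_use_piecewise_cudagraph` is a hypothetical name, not part of the diff):

```python
def should_use_piecewise_cudagraph(entry, forward_context) -> bool:
    # The piecewise backend captures/replays its own cudagraphs only when
    # this entry opted in AND attention is excluded from capture. When
    # skip_attention_cuda_graphs is False, the FullCudagraphWrapper added
    # below owns capture, so every piece must run eagerly inside it.
    return entry.use_cudagraph and forward_context.skip_attention_cuda_graphs
```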
@@ -216,3 +222,137 @@ def __call__(self, *args) -> Any:
 
         entry.cudagraph.replay()
         return entry.output
+
+
+class FullCudagraphWrapper:
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, sym_shape_indices: list[int],
+                 ):
+        self.graph = graph
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
+        self.graph_pool = graph_pool
+        self.sym_shape_indices = sym_shape_indices
+
+        self.separate_attention_routine = vllm_config.compilation_config.separate_attention_routine
+
+        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
+
+        self.first_run_finished = False
+
+        self.cudagraph_capture_sizes: set[int] = set(
+            self.compilation_config.cudagraph_capture_sizes
+        ) if self.compilation_config.use_cudagraph else set()
+
+        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
+        self.concrete_size_entries_decode: dict[int, ConcreteSizeEntry] = {}
+
+
+        for shape in self.cudagraph_capture_sizes:
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(
+                runtime_shape=shape,
+                need_to_compile=False,
+                use_cudagraph=True,
+                usage_type="general",
+            )
+            if self.separate_attention_routine:
+                self.concrete_size_entries_decode[shape] = ConcreteSizeEntry(
+                    runtime_shape=shape,
+                    need_to_compile=False,
+                    use_cudagraph=True,
+                    usage_type="decode",
+                )
+
+    def __call__(self, *args) -> Any:
+        if not self.first_run_finished:
+            self.first_run_finished = True
+            return self.graph(*args)
+        list_args = list(args)
+        runtime_shape = list_args[self.sym_shape_indices[0]].shape[0]
+        forward_context = get_forward_context()
+
+        if forward_context.skip_attention_cuda_graphs:
+            # fall back to the piecewise cudagraphs backend, which is
+            # responsible for capturing and running the piecewise cudagraphs.
+            return self.graph(*args)
+
+        # Otherwise, the fx graph and its sub-graphs are expected to run the
+        # compiled graphs eagerly, so that the whole forward pass can be
+        # captured as a single cudagraph.
+
+        concrete_size_entries = self.concrete_size_entries  # default: general usage
+        if self.separate_attention_routine and forward_context.is_pure_decoding:
+            concrete_size_entries = self.concrete_size_entries_decode
+
+        if runtime_shape not in concrete_size_entries:
+            # we don't need to do anything for this shape.
+            return self.graph(*args)
+
+        entry = concrete_size_entries[runtime_shape]
+
+        if entry.runnable is None:
+            entry.runnable = self.graph
+
+        if not entry.use_cudagraph:
+            return entry.runnable(*args)
+
+        if entry.cudagraph is None:
+            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
+                entry.num_finished_warmup += 1
+                logger.debug(
+                    "Warming up %s/%s of %s usage for shape %s",
+                    entry.num_finished_warmup,
+                    self.compilation_config.cudagraph_num_of_warmups,
+                    entry.usage_type,
+                    runtime_shape)
+                return entry.runnable(*args)
+
+
+            # Since we capture cudagraph for many different shapes and
+            # capturing is fast, we don't need to log it for every shape.
+            # We only log it in the debug mode.
+
+            logger.debug("Capturing a cudagraph of %s usage for shape %s",
+                         entry.usage_type,
+                         runtime_shape)
+
+            input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            entry.input_addresses = input_addresses
+            cudagraph = torch.cuda.CUDAGraph()
+
+            with ExitStack() as stack:
+                # mind-exploding: carefully manage the reference and memory.
+                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+                    # `output` is managed by pytorch's cudagraph pool
+                    output = entry.runnable(*args)
+                    # by converting it to weak ref,
+                    # the original `output` will immediately be released
+                    # to save memory
+                    output = weak_ref_tensors(output)
+
+            # here we always use weak ref for the output
+            # to save memory
+            entry.output = weak_ref_tensors(output)
+            entry.cudagraph = cudagraph
+
+            compilation_counter.num_cudagraph_captured += 1
+
+            # important: we need to return the output, rather than
+            # the weak ref of the output, so that pytorch can correctly
+            # manage the memory during cuda graph capture
+            return output
+
+        if self.is_debugging_mode:
+            # check if the input addresses are the same
+            new_input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            assert new_input_addresses == entry.input_addresses, (
+                "Input addresses for cudagraphs are different during replay."
+                f" Expected {entry.input_addresses}, got {new_input_addresses}"
+            )
+
+        entry.cudagraph.replay()
+        return entry.output
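To summarize the wrapper's control flow for reviewers: `__call__` above routes each invocation one of three ways. A self-contained restatement (the helper and label names here are hypothetical, for illustration only):

```python
from dataclasses import dataclass


@dataclass
class _Ctx:
    """Stand-in for the get_forward_context() fields used by the wrapper."""
    skip_attention_cuda_graphs: bool
    is_pure_decoding: bool


def route(ctx: _Ctx, separate_attention_routine: bool) -> str:
    # Mirrors FullCudagraphWrapper.__call__'s dispatch order.
    if ctx.skip_attention_cuda_graphs:
        # the piecewise backend handles per-piece capture; run the fx graph
        return "piecewise"
    if separate_attention_routine and ctx.is_pure_decoding:
        # pure-decode batches get their own cudagraph entry table
        return "full (decode entries)"
    return "full (general entries)"


assert route(_Ctx(True, False), True) == "piecewise"
assert route(_Ctx(False, True), True) == "full (decode entries)"
assert route(_Ctx(False, True), False) == "full (general entries)"
```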