Skip to content

Commit 2d1b7c7

Browse files
Pass token num cuda graph exec
Signed-off-by: Diego-Castan <[email protected]>
1 parent abd342f commit 2d1b7c7

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2516,7 +2516,7 @@ def profile_run(self) -> None:
25162516
self.encoder_cache.clear()
25172517
gc.collect()
25182518

2519-
def capture_model(self) -> None:
2519+
def capture_model(self, specific_token_num: Optional[int]) -> None:
25202520
if not self.use_cuda_graph:
25212521
logger.warning(
25222522
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
@@ -2550,16 +2550,15 @@ def freeze_gc():
25502550
with freeze_gc(), graph_capture(device=self.device):
25512551
full_cg = self.full_cuda_graph
25522552
# Only rank 0 should print progress bar during capture
2553-
compilation_cases = reversed(self.cudagraph_batch_sizes)
2554-
if is_global_first_rank():
2553+
compilation_cases = [specific_token_num] if specific_token_num else reversed(self.cudagraph_batch_sizes)
2554+
2555+
if is_global_first_rank() and specific_token_num is None:
25552556
compilation_cases = tqdm(
25562557
list(compilation_cases),
25572558
disable=not self.load_config.use_tqdm_on_load,
25582559
desc="Capturing CUDA graph shapes")
25592560
for num_tokens in compilation_cases:
25602561
# We skip EPLB here since we don't want to record dummy metrics
2561-
logger.info("DIEGO: compilation for number of tokens %d",
2562-
num_tokens)
25632562
for _ in range(
25642563
self.compilation_config.cudagraph_num_of_warmups):
25652564
self._dummy_run(num_tokens,

vllm/v1/worker/gpu_worker.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def __init__(
6464

6565
# Buffers saved before sleep
6666
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
67+
68+
# executed cuda graph
69+
self._token_compiled_cudagraphs: set[int] = set()
6770

6871
# Torch profiler. Enabled and configured through env vars:
6972
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
@@ -310,8 +313,8 @@ def compile_or_warm_up_model(self) -> None:
310313
for size in sorted(warmup_sizes, reverse=True):
311314
logger.info("Compile and warming up model for size %d", size)
312315
self.model_runner._dummy_run(size, skip_eplb=True)
313-
if not self.model_config.enforce_eager:
314-
self.model_runner.capture_model()
316+
# if not self.model_config.enforce_eager:
317+
# self.model_runner.capture_model()
315318

316319
# Warm up sampler and preallocate memory buffer for logits and other
317320
# sampling related tensors of max possible shape to avoid memory
@@ -355,6 +358,12 @@ def execute_model(
355358
get_pp_group().recv_tensor_dict(
356359
all_gather_group=get_tp_group()))
357360

361+
# Adding capture model in execution time
362+
if scheduler_output.total_num_scheduled_tokens not in self._token_compiled_cudagraphs:
363+
logger.info("DIEGO: CUDAgraph in execution time for %d input tokens", scheduler_output.total_num_scheduled_tokens)
364+
self._token_compiled_cudagraphs.add(scheduler_output.total_num_scheduled_tokens)
365+
self.model_runner.capture_model(scheduler_output.total_num_scheduled_tokens)
366+
358367
output = self.model_runner.execute_model(scheduler_output,
359368
intermediate_tensors)
360369

0 commit comments

Comments
 (0)