
Commit 9e1f1af

misc gpu model runner refactoring

Signed-off-by: Sage Moore <[email protected]>
1 parent: 10518bd

File tree

1 file changed: +13 -7 lines changed


vllm/v1/worker/gpu_model_runner.py

Lines changed: 13 additions & 7 deletions
@@ -593,8 +593,8 @@ def _ubatch_split(
         if not self.parallel_config.enable_microbatching:
             return (None, 0, None)
 
+        # Check preconditions for microbatching
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        num_reqs = self.input_batch.num_reqs
         should_attempt_ubatching = \
             self.parallel_config.enable_microbatching and \
             total_num_scheduled_tokens >= \
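
The hunk above ends mid-statement, so the right-hand side of the comparison is not shown. A minimal sketch of the precondition gate, assuming the truncated comparison is against the parallel_config.microbatching_token_threshold attribute that appears later in this same diff (the standalone-function form is illustrative; the real check lives inline in _ubatch_split):

    # Hedged sketch, not the actual vLLM code.
    def should_attempt_ubatching(parallel_config,
                                 total_num_scheduled_tokens: int) -> bool:
        # Microbatching must be enabled, and enough tokens must be
        # scheduled to make splitting into two ubatches worthwhile.
        # (Assumption: the truncated comparison targets this threshold.)
        return (parallel_config.enable_microbatching
                and total_num_scheduled_tokens
                >= parallel_config.microbatching_token_threshold)
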
@@ -610,11 +610,13 @@ def _ubatch_split(
         if not should_ubatch:
             return (None, 0, None)
 
-        # This doesn't actually pad the ubatch slices. It just initialize the
-        # split point to the correct value so that padding can be applied
+        # This doesn't actually pad the ubatch slices. It just initializes the
+        # split point to the padded value so that padding can be applied
         # to the second ubatch in pad_out_ubatch_slice after attention
         # metadata creation
-        assert num_pad_tokens < total_num_scheduled_tokens, f"num_pad_tokens {num_pad_tokens} original_num_tokens {total_num_scheduled_tokens}"
+        assert num_pad_tokens < total_num_scheduled_tokens,\
+            f"num_pad_tokens {num_pad_tokens} "\
+            f"original_num_tokens {total_num_scheduled_tokens}"
         total_num_tokens_per_ubatch = (total_num_scheduled_tokens +
                                        num_pad_tokens) // 2
         padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
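
A worked example of the split-point arithmetic in the hunk above, with illustrative numbers (second_ubatch_slice is a hypothetical name for the region that pad_out_ubatch_slice later pads; it is not in the diff):

    # Illustrative values only.
    total_num_scheduled_tokens = 100  # tokens actually scheduled
    num_pad_tokens = 6                # padding needed for an even split

    # Precondition enforced by the assert in the hunk above.
    assert num_pad_tokens < total_num_scheduled_tokens

    # Each ubatch gets half of the padded total: (100 + 6) // 2 == 53.
    total_num_tokens_per_ubatch = (total_num_scheduled_tokens +
                                   num_pad_tokens) // 2

    # The first ubatch covers [0, 53). The second starts at the split
    # point with 47 real tokens and is padded out to 53 after attention
    # metadata creation.
    padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
    second_ubatch_slice = slice(total_num_tokens_per_ubatch,
                                total_num_scheduled_tokens)  # [53, 100)
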
@@ -2945,14 +2947,17 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
             "decode" if uniform_decode else "mixed prefill-decode",
             cudagraph_runtime_mode.name))
         enable_microbatching = self.parallel_config.enable_microbatching
-        # We skip EPLB here since we don't want to record dummy metrics
+        # DBO Only supports running Full cudagraphs with uniform
+        # decode lengths
         if enable_microbatching and uniform_decode:
             for num_tokens in compilation_cases:
                 # If the number of tokens is greater than the microbatching
                 # threshold, don't generate a microbatched cudagraph
                 if (num_tokens
                         < self.parallel_config.microbatching_token_threshold):
                     continue
+
+                # Warmup
                 for _ in range(
                         self.compilation_config.cudagraph_num_of_warmups):
                     force_attention = (
@@ -2963,13 +2968,14 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
                                     uniform_decode=True,
                                     allow_microbatching=True,
                                     skip_eplb=True)
-                # DBO Only supports running with Full cudagraphs with uniform
-                # decode lengths
+
+                # Graph Capture
                 self._dummy_run(num_tokens,
                                 cudagraph_runtime_mode=CUDAGraphMode.FULL,
                                 uniform_decode=True,
                                 allow_microbatching=True,
                                 skip_eplb=True)
+        # We skip EPLB here since we don't want to record dummy metrics
         for num_tokens in compilation_cases:
             for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                 # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
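
The restructuring above separates the two dummy runs into an explicit warmup phase and a graph-capture phase. A minimal, self-contained sketch of that pattern in plain PyTorch (the helper name and run_fn callback are hypothetical; vLLM drives both phases through self._dummy_run instead):

    import torch

    def warmup_then_capture(run_fn, num_warmups: int = 3) -> torch.cuda.CUDAGraph:
        # Warmup on a side stream (per the PyTorch CUDA-graphs docs) so
        # one-time work such as allocator growth, kernel loading, and
        # autotuning happens outside the capture.
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            for _ in range(num_warmups):
                run_fn()
        torch.cuda.current_stream().wait_stream(side)

        # Capture: record the same workload into a replayable graph.
        # run_fn must use fixed shapes and capture-safe ops only.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            run_fn()
        return graph

Replaying the captured graph later reuses the recorded kernel launches, which is why the capture call pins uniform_decode=True and CUDAGraphMode.FULL: the shapes replayed must match the shapes captured.
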
