
Commit 23194d8

[BugFix] Fix DP/EP hang (vllm-project#25906)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 61aedb5 commit 23194d8

1 file changed (+15, -3 lines)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 15 additions & 3 deletions
@@ -3075,13 +3075,19 @@ def _dummy_run(
             # We currently only microbatch if the number of tokens is
             # over a certain threshold.
             if self.parallel_config.enable_dbo and allow_microbatching:
-                ubatch_slices, num_tokens_after_padding = ubatch_split(
+                ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
                     num_scheduled_tokens,
                     total_num_scheduled_tokens,
                     total_num_scheduled_tokens,
                     uniform_decode=uniform_decode,
                     vllm_config=self.vllm_config,
                 )
+                # Currently when DBO is enabled `ubatch_split` returns
+                # the num_tokens_after_padding for a single ubatch, but we have 2
+                # TODO(sage,lucas): this is cruft that should be addressed in the
+                # padding refactor.
+                if ubatch_num_tokens_after_padding is not None:
+                    num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
 
             # If we failed to microbatch, currently need to resynchronize
             # TODO(lucas,sage): we should be able to avoid this second sync by
@@ -3198,7 +3204,7 @@ def _dummy_run(
 
         # filter out the valid batch descriptor
         _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
-            BatchDescriptor(num_tokens=num_tokens,
+            BatchDescriptor(num_tokens=num_tokens_after_padding,
                             uniform_decode=uniform_decode)) \
             if not is_profile else (CUDAGraphMode.NONE, None)
         if cudagraph_runtime_mode is not None:
@@ -3212,7 +3218,13 @@ def _dummy_run(
             cudagraph_runtime_mode = _cg_mode
 
         if ubatch_slices is not None:
-            num_tokens = num_tokens // 2
+            # Adjust values to reflect a single ubatch.
+            # TODO(sage,lucas): this is cruft that should be addressed in
+            # the padding refactor.
+            num_tokens_after_padding = ubatch_slices[0].num_tokens
+            if num_tokens_across_dp is not None:
+                num_tokens_across_dp[:] = num_tokens_after_padding
+
         with self.maybe_randomize_inputs(input_ids), set_forward_context(
             attn_metadata,
             self.vllm_config,
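
For context on the padding bookkeeping the diff performs, below is a minimal standalone sketch; the helper name dummy_run_token_counts and the toy num_tokens_across_dp array are hypothetical stand-ins, not vLLM's actual API. The arithmetic mirrors the diff: ubatch_split reports the padded size of a single ubatch, so the full dummy-run size used to dispatch a CUDA-graph batch descriptor is twice that, while each forward pass covers one ubatch; every DP/EP rank must see the same per-ubatch count, otherwise ranks issue mismatched collectives and the run hangs.

from typing import Optional

import numpy as np


def dummy_run_token_counts(
    num_tokens: int,
    ubatch_num_tokens_after_padding: Optional[int],
    num_tokens_across_dp: Optional[np.ndarray],
) -> tuple[int, int]:
    # Return (full padded size, per-ubatch size) for a dummy run.
    if ubatch_num_tokens_after_padding is None:
        # No micro-batching: one batch, one size.
        return num_tokens, num_tokens

    # ubatch_split reports padding for a single ubatch, but DBO runs 2,
    # so the CUDA-graph batch descriptor is keyed on twice that size.
    full_padded = ubatch_num_tokens_after_padding * 2

    # Each forward pass covers one ubatch; all DP ranks must agree on this
    # count or they issue different numbers of collectives and hang.
    per_ubatch = ubatch_num_tokens_after_padding
    if num_tokens_across_dp is not None:
        num_tokens_across_dp[:] = per_ubatch
    return full_padded, per_ubatch


if __name__ == "__main__":
    across_dp = np.zeros(4, dtype=np.int64)  # 4 hypothetical DP ranks
    full, per_ubatch = dummy_run_token_counts(
        num_tokens=128,
        ubatch_num_tokens_after_padding=64,
        num_tokens_across_dp=across_dp,
    )
    print(full, per_ubatch, across_dp)  # 128 64 [64 64 64 64]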
