@@ -3075,13 +3075,19 @@ def _dummy_run(
3075
3075
# We currently only microbatch if the number of tokens is
3076
3076
# over a certain threshold.
3077
3077
if self .parallel_config .enable_dbo and allow_microbatching :
3078
- ubatch_slices , num_tokens_after_padding = ubatch_split (
3078
+ ubatch_slices , ubatch_num_tokens_after_padding = ubatch_split (
3079
3079
num_scheduled_tokens ,
3080
3080
total_num_scheduled_tokens ,
3081
3081
total_num_scheduled_tokens ,
3082
3082
uniform_decode = uniform_decode ,
3083
3083
vllm_config = self .vllm_config ,
3084
3084
)
3085
+ # Currently when DBO is enabled `ubatch_split` returns
3086
+ # the num_tokens_after_padding for a single ubatch, but we have 2 ubatches, so the value is doubled below.
3087
+ # TODO(sage,lucas): this is cruft that should be addressed in the
3088
+ # padding refactor.
3089
+ if ubatch_num_tokens_after_padding is not None :
3090
+ num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
3085
3091
3086
3092
# If we failed to microbatch, currently need to resynchronize
3087
3093
# TODO(lucas,sage): we should be able to avoid this second sync by
@@ -3198,7 +3204,7 @@ def _dummy_run(
3198
3204
3199
3205
# filter out the valid batch descriptor
3200
3206
_cg_mode , batch_descriptor = self .cudagraph_dispatcher .dispatch (
3201
- BatchDescriptor (num_tokens = num_tokens ,
3207
+ BatchDescriptor (num_tokens = num_tokens_after_padding ,
3202
3208
uniform_decode = uniform_decode )) \
3203
3209
if not is_profile else (CUDAGraphMode .NONE , None )
3204
3210
if cudagraph_runtime_mode is not None :
@@ -3212,7 +3218,13 @@ def _dummy_run(
3212
3218
cudagraph_runtime_mode = _cg_mode
3213
3219
3214
3220
if ubatch_slices is not None :
3215
- num_tokens = num_tokens // 2
3221
+ # Adjust values to reflect a single ubatch.
3222
+ # TODO(sage,lucas): this is cruft that should be addressed in
3223
+ # the padding refactor.
3224
+ num_tokens_after_padding = ubatch_slices [0 ].num_tokens
3225
+ if num_tokens_across_dp is not None :
3226
+ num_tokens_across_dp [:] = num_tokens_after_padding
3227
+
3216
3228
with self .maybe_randomize_inputs (input_ids ), set_forward_context (
3217
3229
attn_metadata ,
3218
3230
self .vllm_config ,
0 commit comments