
Commit 756d721

dp metadata refactor

Sage Moore committed
Signed-off-by: Sage Moore <[email protected]>
1 parent 7427b2d commit 756d721

File tree

vllm/compilation/ubatch_wrapper.py
vllm/forward_context.py
vllm/v1/worker/gpu_model_runner.py

3 files changed (+18, -24 lines)

vllm/compilation/ubatch_wrapper.py

Lines changed: 6 additions & 6 deletions
@@ -185,7 +185,7 @@ def _ubatch_thread(results, model, ubatch_metadata):
 
     def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
                               positions, inputs_embeds, intermediate_tensors,
-                              compute_stream, num_tokens_across_dp,
+                              compute_stream, dp_metadata,
                               batch_descriptor,
                               cudagraph_runtime_mode) -> list[UbatchMetadata]:
 
@@ -198,8 +198,7 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
                 create_forward_context(
                     attn_metadata[i] if attn_metadata is not None else None,
                     self.vllm_config,
-                    num_tokens=num_tokens,
-                    num_tokens_across_dp=num_tokens_across_dp,
+                    dp_metadata=dp_metadata,
                     batch_descriptor=batch_descriptor,
                     cudagraph_runtime_mode=cudagraph_runtime_mode))
 
@@ -270,8 +269,9 @@ def __call__(self, *args, **kwargs):
         compute_stream = torch.cuda.current_stream()
 
         dp_metadata = forward_context.dp_metadata
+
+        # We shouldn't be here unless we are running with multiple DP ranks
         assert dp_metadata is not None
-        num_tokens_across_dp = dp_metadata._num_tokens_across_dp
 
         if num_tokens not in self.cudagraphs \
                 and cudagraph_runtime_mode is CUDAGraphMode.FULL:
@@ -283,7 +283,7 @@
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 compute_stream=compute_stream,
-                num_tokens_across_dp=num_tokens_across_dp,
+                dp_metadata=dp_metadata,
                 batch_descriptor=batch_descriptor,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE)
 
@@ -301,7 +301,7 @@
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 compute_stream=compute_stream,
-                num_tokens_across_dp=num_tokens_across_dp,
+                dp_metadata=dp_metadata,
                 batch_descriptor=batch_descriptor,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE)
         return self._run_ubatches(ubatch_metadata, self.model)
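
With this change, `_make_ubatch_metadata` takes the `DPMetadata` object itself rather than the raw `num_tokens_across_dp` tensor, and `__call__` no longer reaches into the private `_num_tokens_across_dp` field. A minimal sketch of the resulting call pattern, under simplified, hypothetical signatures (the real `UbatchWrapper` and `create_forward_context` carry much more state):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DPMetadata:
    # Stand-in for vLLM's DPMetadata; the real class holds more DP state.
    num_tokens_across_dp: Optional[Any] = None


def create_forward_context(attn_metadata: Any,
                           dp_metadata: Optional[DPMetadata] = None) -> dict:
    # Simplified stand-in: the real function also takes vllm_config,
    # cudagraph_runtime_mode, batch_descriptor, and ubatch_slices.
    return {"attn_metadata": attn_metadata, "dp_metadata": dp_metadata}


def make_ubatch_metadata(ubatch_slices: list, attn_metadata: Optional[list],
                         dp_metadata: Optional[DPMetadata]) -> list[dict]:
    # Each micro-batch gets its own forward context, but all of them share
    # the same DPMetadata object; no field is unpacked along the way.
    return [
        create_forward_context(
            attn_metadata[i] if attn_metadata is not None else None,
            dp_metadata=dp_metadata)
        for i in range(len(ubatch_slices))
    ]


# Usage: two micro-batches sharing one DPMetadata.
contexts = make_ubatch_metadata(ubatch_slices=[0, 1], attn_metadata=None,
                                dp_metadata=DPMetadata([8, 8]))
assert contexts[0]["dp_metadata"] is contexts[1]["dp_metadata"]
```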

vllm/forward_context.py

Lines changed: 9 additions & 11 deletions
@@ -237,18 +237,10 @@ def get_forward_context() -> ForwardContext:
 def create_forward_context(attn_metadata: Any,
                            vllm_config: VllmConfig,
                            virtual_engine: int = 0,
-                           num_tokens: Optional[int] = None,
-                           num_tokens_across_dp: Optional[torch.Tensor] = None,
+                           dp_metadata: Optional[DPMetadata] = None,
                            cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                            batch_descriptor: Optional[BatchDescriptor] = None,
                            ubatch_slices: Optional[UBatchSlices] = None):
-    dp_metadata: Optional[DPMetadata] = None
-    if vllm_config.parallel_config.data_parallel_size > 1 and (
-            attn_metadata is not None or num_tokens is not None):
-        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
-                                      attn_metadata, num_tokens or 0,
-                                      num_tokens_across_dp)
-
     return ForwardContext(no_compile_layers=vllm_config.compilation_config.
                           static_forward_context,
                           virtual_engine=virtual_engine,
@@ -293,9 +285,15 @@ def set_forward_context(
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
 
+    dp_metadata: Optional[DPMetadata] = None
+    if vllm_config.parallel_config.data_parallel_size > 1 and (
+            attn_metadata is not None or num_tokens is not None):
+        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
+                                      attn_metadata, num_tokens or 0,
+                                      num_tokens_across_dp)
+
     forward_context = create_forward_context(attn_metadata, vllm_config,
-                                             virtual_engine, num_tokens,
-                                             num_tokens_across_dp,
+                                             virtual_engine, dp_metadata,
                                              cudagraph_runtime_mode, batch_descriptor,
                                              ubatch_slices)
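
Net effect: `DPMetadata` construction is hoisted out of `create_forward_context`, which becomes a plain constructor for `ForwardContext`, and into `set_forward_context`, the single entry point that still sees `num_tokens` and `num_tokens_across_dp`. A rough sketch of the resulting control flow, using hypothetical simplified signatures rather than the actual vLLM API:

```python
from typing import Any, Optional


class DPMetadata:
    """Stand-in for vLLM's DPMetadata."""

    def __init__(self, num_tokens_across_dp: Any) -> None:
        self.num_tokens_across_dp = num_tokens_across_dp

    @staticmethod
    def make(dp_size: int, num_tokens: int,
             num_tokens_across_dp: Any) -> "DPMetadata":
        # Simplified: the real DPMetadata.make takes parallel_config and
        # attn_metadata, and may gather token counts across DP ranks.
        return DPMetadata(num_tokens_across_dp or [num_tokens] * dp_size)


def create_forward_context(attn_metadata: Any,
                           dp_metadata: Optional[DPMetadata] = None) -> dict:
    # After the refactor this no longer decides whether DP metadata is
    # needed; it just stores whatever the caller already built.
    return {"attn_metadata": attn_metadata, "dp_metadata": dp_metadata}


def set_forward_context(attn_metadata: Any, dp_size: int,
                        num_tokens: Optional[int] = None,
                        num_tokens_across_dp: Any = None) -> dict:
    # The construction logic now lives only in this entry point.
    dp_metadata: Optional[DPMetadata] = None
    if dp_size > 1 and (attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(dp_size, num_tokens or 0,
                                      num_tokens_across_dp)
    return create_forward_context(attn_metadata, dp_metadata)
```

One consequence of this split: a caller that already holds a `DPMetadata`, such as the ubatch wrapper above, can invoke `create_forward_context` directly without re-deriving it or round-tripping through raw tensors.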

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 7 deletions
@@ -1686,13 +1686,6 @@ def execute_model(
                                               num_input_tokens)
         num_input_tokens += num_pad
 
-        uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
-            num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
-        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
-                                           uniform_decode=uniform_decode)
-        cudagraph_runtime_mode, batch_descriptor = \
-            self.cudagraph_dispatcher.dispatch(batch_descriptor)
-
         if self.supports_mm_inputs:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
@@ -1747,6 +1740,8 @@ def execute_model(
                                            uniform_decode=uniform_decode)
         cudagraph_runtime_mode, batch_descriptor = \
             self.cudagraph_dispatcher.dispatch(batch_descriptor)
+
+        logger.info(f"NUM TOKENS: {num_input_tokens} cudagraph_runtime_mode {cudagraph_runtime_mode} UBATCHING {ubatch_slices is not None}")
 
         # Run the model.
         # Use persistent buffers for CUDA graphs.
@@ -3138,6 +3133,7 @@ def initialize_cudagraph_capture(self) -> None:
 
         # Trigger cudagraph dispatching keys initialization here (after
         # initializing attn backends).
+        logger.info(f"INITIALIZING KEYS FOR MODE: {self.compilation_config.cudagraph_mode}")
        self.cudagraph_dispatcher.initialize_cudagraph_keys(
             self.compilation_config.cudagraph_mode,
             self.uniform_decode_query_len)
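
The two added `logger.info(f"...")` calls use f-strings, which are formatted eagerly even when INFO logging is disabled. With the standard `logging` module that vLLM's logger wraps, lazy %-style formatting defers that work until a record is actually emitted; a small illustrative sketch with placeholder values standing in for the runner's real state:

```python
import logging

logger = logging.getLogger(__name__)

# Placeholder values, not the runner's actual attributes.
num_input_tokens = 512
cudagraph_runtime_mode = "FULL"
ubatch_slices = None

# Eager: the f-string is built even if INFO is disabled.
logger.info(f"NUM TOKENS: {num_input_tokens} UBATCHING {ubatch_slices is not None}")

# Lazy: arguments are only formatted if the record is emitted.
logger.info("NUM TOKENS: %d cudagraph_runtime_mode %s UBATCHING %s",
            num_input_tokens, cudagraph_runtime_mode,
            ubatch_slices is not None)
```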
