
Commit 6d646d0
[Core] Optimize Async + Multi-step (vllm-project#8050)
1 parent 95a178f

8 files changed (+326, -248 lines)

tests/multi_step/test_correctness_async_llm.py

Lines changed: 2 additions & 2 deletions

@@ -103,13 +103,13 @@ async def test_multi_step(
         model,
         server_args + distributed_args,
         num_logprobs,
-        max_wait_seconds=3 * 240)
+        max_wait_seconds=5 * 240)
     test_completions = await completions_with_server_args(
         prompts,
         model,
         ms_server_args + distributed_args,
         num_logprobs,
-        max_wait_seconds=3 * 240)
+        max_wait_seconds=5 * 240)

     # Assert multi-step scheduling produces identical tokens
     # to single-step scheduling.
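The only change here is raising the server wait budget from 3 * 240 to 5 * 240 seconds. For context, a minimal sketch of how a max_wait_seconds bound like this is typically enforced around an async call; the helper names below are hypothetical, not vLLM's:

    import asyncio

    async def _query_server(prompts):
        # Hypothetical stand-in for the real OpenAI-compatible server round trip.
        await asyncio.sleep(0.01)
        return [f"completion for {p}" for p in prompts]

    async def completions_with_timeout(prompts, max_wait_seconds: float):
        # asyncio.wait_for raises asyncio.TimeoutError once the budget
        # (5 * 240 seconds in the updated test) is exhausted.
        return await asyncio.wait_for(_query_server(prompts),
                                      timeout=max_wait_seconds)

    print(asyncio.run(completions_with_timeout(["hi"], max_wait_seconds=5 * 240)))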

vllm/engine/async_llm_engine.py

Lines changed: 54 additions & 55 deletions
@@ -280,40 +280,27 @@ async def step_async(
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc

-        # Detect async + multi-step
-        use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                    and allow_async_output_proc)
-
         ctx = self.scheduler_contexts[virtual_engine]

+        # Clear outputs for each new scheduler iteration
+        ctx.request_outputs.clear()
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):

-            # Clear outputs on scheduler iteration start
-            ctx.request_outputs.clear()
-
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
              ) = self.scheduler[virtual_engine].schedule()

-            # Detect async + multi-step
-            use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                        and allow_async_output_proc)
+            ctx.seq_group_metadata_list = seq_group_metadata_list
+            ctx.scheduler_outputs = scheduler_outputs

             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
-
-            # For async + multi-step, init the queue
-            if use_async_and_multi_step:
-                assert len(ctx.output_queue) == 0
-                assert seq_group_metadata_list is not None
-                ctx.output_queue.append(
-                    (None, seq_group_metadata_list, scheduler_outputs))
+                self._process_model_outputs(ctx=ctx)

             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
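This hunk removes the use_async_and_multi_step flag: request_outputs is now cleared once at the top of every iteration, and the schedule results are recorded on the per-virtual-engine context. A minimal sketch of the context object this relies on, with a hypothetical, simplified field set:

    from collections import deque
    from dataclasses import dataclass, field
    from typing import Any, Deque, List, Optional

    @dataclass
    class SchedulerContextSketch:
        # Simplified stand-in for the per-virtual-engine scheduler context:
        # model outputs queued for async processing, plus the most recent
        # schedule results, so callers pass one ctx instead of separate flags.
        output_queue: Deque[Any] = field(default_factory=deque)
        request_outputs: List[Any] = field(default_factory=list)
        seq_group_metadata_list: Optional[List[Any]] = None
        scheduler_outputs: Optional[Any] = None

    ctx = SchedulerContextSketch()
    ctx.request_outputs.clear()       # cleared at the start of every iteration
    ctx.seq_group_metadata_list = []  # recorded right after schedule()
    ctx.scheduler_outputs = object()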
@@ -351,26 +338,20 @@ async def step_async(
                 last_sampled_token_ids=last_sampled_token_ids)

             if allow_async_output_proc:
-                async_callback = self.async_callback_multi_step[
-                    virtual_engine] if use_async_and_multi_step \
-                    else self.async_callback[virtual_engine]
-
-                execute_model_req.async_callback = async_callback
-                execute_model_req.use_async_and_multi_step = \
-                    use_async_and_multi_step
+                execute_model_req.async_callback = self.async_callbacks[
+                    virtual_engine]

             # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)
+
             # we need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+            if len(ctx.output_queue) > 0:
+                self._process_model_outputs(ctx=ctx)
             output = []

         # Finish the current step for all the sequence groups.
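With the flag gone, a single pre-bound callback per virtual engine replaces the async_callback / async_callback_multi_step pair. A sketch of how such per-engine callbacks can be pre-bound; the class below is a simplified stand-in, not vLLM's:

    import functools

    class EngineSketch:
        def __init__(self, num_virtual_engines: int):
            # One pre-bound callback per virtual engine; the worker only
            # invokes it, without branching on single- vs multi-step.
            self.async_callbacks = [
                functools.partial(self._process_model_outputs, ctx=ve)
                for ve in range(num_virtual_engines)
            ]

        def _process_model_outputs(self, ctx):
            print(f"processing outputs for virtual engine context {ctx}")

    engine = EngineSketch(num_virtual_engines=2)
    engine.async_callbacks[1]()  # what execute_model_req.async_callback invokes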
@@ -384,24 +365,22 @@ async def step_async(
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()

-            if use_async_and_multi_step:
-                # For async + multi-step, clear the queue
-                ctx.output_queue.clear()
-            else:
-                ctx.output_queue.append(
-                    (output, seq_group_metadata_list, scheduler_outputs))
+            is_async = allow_async_output_proc
+            is_last_step = True
+            ctx.output_queue.append(
+                (output, seq_group_metadata_list, scheduler_outputs, is_async,
+                 is_last_step))

-                if output and allow_async_output_proc:
-                    assert len(
-                        output
-                    ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
-                    self._advance_to_next_step(
-                        output[0], seq_group_metadata_list,
-                        scheduler_outputs.scheduled_seq_groups)
+            if output and allow_async_output_proc:
+                assert len(
+                    output
+                ) == 1, "Async postprocessor expects only a single output set"
+                self._advance_to_next_step(
+                    output[0], seq_group_metadata_list,
+                    scheduler_outputs.scheduled_seq_groups)

             if not allow_async_output_proc:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=False)
+                self._process_model_outputs(ctx=ctx)

                 # Log stats.
                 self.do_log_stats(scheduler_outputs, output)
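Queue entries grow from 3-tuples to 5-tuples so is_async and is_last_step travel with the outputs instead of being re-derived by the consumer. A sketch of that producer/consumer handshake, under the assumption of plain placeholder values:

    from collections import deque

    output_queue = deque()

    # Producer (mirrors the diff): every iteration appends one 5-tuple, so
    # the consumer reads the mode from the entry, not a separate code path.
    output = ["sampler-output"]
    seq_group_metadata_list, scheduler_outputs = [], None
    is_async = True      # allow_async_output_proc at append time
    is_last_step = True  # single-step engines finish a batch in one step
    output_queue.append(
        (output, seq_group_metadata_list, scheduler_outputs, is_async,
         is_last_step))

    # Consumer: unpack the flags carried alongside the outputs.
    (output, seq_group_metadata_list, scheduler_outputs, is_async,
     is_last_step) = output_queue.popleft()
    print(is_async, is_last_step)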
@@ -411,17 +390,12 @@ async def step_async(

         else:
             # Multi-step case
-            if use_async_and_multi_step:
-                return []
-            else:
-                ctx.request_outputs = []
+            return ctx.request_outputs

         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
             if len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+                self._process_model_outputs(ctx=ctx)
             assert len(ctx.output_queue) == 0

         return ctx.request_outputs
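Multi-step iterations now return the accumulated ctx.request_outputs directly, and the final drain no longer special-cases multi-step. A sketch of the drain invariant this hunk preserves; the helper name is hypothetical:

    from collections import deque

    def drain_outputs(output_queue, process):
        # Once no unfinished requests remain, flush everything still queued
        # for the async postprocessor before returning the final outputs.
        while len(output_queue) > 0:
            process(output_queue.popleft())
        assert len(output_queue) == 0

    queue = deque(["pending-1", "pending-2"])
    drain_outputs(queue, process=print)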
@@ -640,6 +614,17 @@ def __init__(self,
         self.log_requests = log_requests
         self.engine = self._init_engine(*args, **kwargs)

+        # This ensures quick processing of request outputs
+        # so the append to asyncio queues is not delayed,
+        # especially for multi-step.
+        #
+        # TODO: Currently, disabled for engine_use_ray, ask
+        # Cody/Will/Woosuk about this case.
+        self.use_process_request_outputs_callback = not self.engine_use_ray
+        if self.use_process_request_outputs_callback:
+            self.engine.process_request_outputs_callback = \
+                self.process_request_outputs
+
         if self.engine_use_ray:
             print_warning_once(
                 "DEPRECATED. `--engine-use-ray` is deprecated and will "
@@ -883,13 +868,27 @@ async def engine_step(self, virtual_engine: int) -> bool:
             request_outputs = await self.engine.step_async(virtual_engine)

         # Put the outputs into the corresponding streams.
-        finished = True
+        # If used as a callback, then already invoked inside
+        # LLMEngine's _process_model_outputs
+        if not self.use_process_request_outputs_callback:
+            all_finished = self.process_request_outputs(request_outputs)
+        else:
+            # For callback case, we only need to detect when all
+            # requests are finished
+            all_finished = all(request_output.finished
+                               for request_output in request_outputs)
+
+        return not all_finished
+
+    def process_request_outputs(self, request_outputs) -> bool:
+        # Put the outputs into the corresponding streams.
+        all_finished = True
         for request_output in request_outputs:
             self._request_tracker.process_request_output(
                 request_output, verbose=self.log_requests)
-            finished = finished and request_output.finished
+            all_finished = all_finished and request_output.finished

-        return not finished
+        return all_finished

     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
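engine_step now delegates streaming to process_request_outputs and, in the callback case, only checks completion. A sketch of that completion check; RequestOutputSketch is a hypothetical stand-in for vLLM's RequestOutput:

    from dataclasses import dataclass

    @dataclass
    class RequestOutputSketch:
        request_id: str
        finished: bool

    outputs = [RequestOutputSketch("a", True), RequestOutputSketch("b", False)]

    # Callback path: streaming already happened inside _process_model_outputs,
    # so engine_step only detects completion and reports "more work remains".
    all_finished = all(o.finished for o in outputs)
    has_more_work = not all_finished
    print(has_more_work)  # True: request "b" is still running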
