
Commit d90d8eb

[BugFix] Async scheduling and PP compatibility with DP (vllm-project#23770)
Signed-off-by: Nick Hill <[email protected]>
1 parent 0a2f4c0 · commit d90d8eb

7 files changed: +107 −100 lines


tests/v1/engine/test_engine_core.py

Lines changed: 22 additions & 37 deletions
@@ -306,17 +306,17 @@ def shutdown(self):
 
     # Schedule Batch 1: (10, req0)
     assert engine_core.step_with_batch_queue()[0] is None
-    assert engine_core.batch_queue.qsize() == 1
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert len(engine_core.batch_queue) == 1
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["0"] == 10
     # num_computed_tokens should have been updated immediately.
     assert engine_core.scheduler.requests[
         req0.request_id].num_computed_tokens == 10
 
     # Schedule Batch 2: (2, req0), (8, req1)
-    assert engine_core.step_with_batch_queue()[0] is None
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert engine_core.step_with_batch_queue()[0] == {}
+    assert len(engine_core.batch_queue) == 1
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["0"] == 2
     assert scheduler_output.num_scheduled_tokens["1"] == 8
     # num_computed_tokens should have been updated immediately.
@@ -325,62 +325,47 @@ def shutdown(self):
 
     assert engine_core.scheduler.get_num_unfinished_requests() == 2
 
-    # Batch queue is full. Finish Batch 1.
-    engine_core.step_with_batch_queue()
-
-    # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
+    # Finish Batch 1 and schedule Batch 3: (4, req1).
+    # Note that req0 cannot be scheduled
     # because it is in the decoding stage now.
     engine_core.step_with_batch_queue()
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert len(engine_core.batch_queue) == 1
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["1"] == 4
 
-    # Batch queue is full. Finish Batch 2. Get first token of req0.
+    # Finish Batch 2. Get first token of req0.
+    # Schedule Batch 4: (1, req0).
    output = engine_core.step_with_batch_queue()[0].get(0)
     assert output is not None
     assert len(output.outputs) == 1
     assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
-
-    # Schedule Batch 4: (1, req0).
-    engine_core.step_with_batch_queue()
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["0"] == 1
 
-    # Batch queue is full. Finish Batch 3. Get first token of req1.
+    # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1).
     output = engine_core.step_with_batch_queue()[0].get(0)
     assert output is not None
     assert len(output.outputs) == 1
     assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
-
-    # Schedule Batch 5: (1, req1).
-    engine_core.step_with_batch_queue()
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["1"] == 1
 
     # Loop until req0 is finished.
-    step = 0
     req_id = 0
     expected_num_tokens = [
         engine_core.scheduler.requests["0"].num_tokens + 1,
         engine_core.scheduler.requests["1"].num_tokens + 1,
     ]
     while engine_core.scheduler.get_num_unfinished_requests() == 2:
         output = engine_core.step_with_batch_queue()[0]
-        if step % 2 == 0:
-            # Even steps consumes an output.
-            assert output is not None
-            assert len(output[0].outputs) == 1
-            if req_id in engine_core.scheduler.requests:
-                assert engine_core.scheduler.requests[
-                    req_id].num_tokens == expected_num_tokens[req_id]
-            expected_num_tokens[req_id] += 1
-            req_id = (req_id + 1) % 2
-        else:
-            # Odd steps schedules a new batch.
-            assert output is None
-        step += 1
+        # Every step consumes an output.
+        assert output is not None
+        assert len(output[0].outputs) == 1
+        if req_id in engine_core.scheduler.requests:
+            assert engine_core.scheduler.requests[
+                req_id].num_tokens == expected_num_tokens[req_id]
+        expected_num_tokens[req_id] += 1
+        req_id = (req_id + 1) % 2
 
 
 @multi_gpu_test(num_gpus=2)
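
The switch from queue.Queue to collections.deque is what lets the test peek at the most recently scheduled batch directly. A minimal sketch of the container semantics the new assertions rely on (illustrative values, not the test fixture):

# Minimal sketch (not the vLLM test): collections.deque supports len() and
# direct indexing, unlike queue.Queue, so the newest batch can be inspected
# without reaching into a private .queue attribute.
from collections import deque

batch_queue = deque(maxlen=2)                  # stands in for engine_core.batch_queue
batch_queue.appendleft(("future-1", "sched-out-1"))

assert len(batch_queue) == 1                   # replaces batch_queue.qsize() == 1
assert batch_queue[-1][1] == "sched-out-1"     # replaces batch_queue.queue[-1][1]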

tests/v1/test_async_llm_dp.py

Lines changed: 4 additions & 2 deletions
@@ -75,9 +75,10 @@ async def generate(
     ],
 )
 @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
+@pytest.mark.parametrize("async_scheduling", [True, False])
 @pytest.mark.asyncio
-async def test_load(output_kind: RequestOutputKind,
-                    data_parallel_backend: str):
+async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str,
+                    async_scheduling: bool):
 
     stats_loggers = {}
 
@@ -105,6 +106,7 @@ def log_engine_initialized(self):
     prompt = "This is a test of data parallel"
 
     engine_args.data_parallel_backend = data_parallel_backend
+    engine_args.async_scheduling = async_scheduling
     engine = AsyncLLM.from_engine_args(engine_args,
                                        stat_loggers=[SimpleStatsLogger])
     after.callback(engine.shutdown)

vllm/executor/ray_utils.py

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,7 @@
 
 import vllm.platforms
 from vllm.config import ParallelConfig
+from vllm.distributed import get_pp_group
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -136,6 +137,11 @@ def execute_model_ray(
                 scheduler_output, intermediate_tensors)
             if isinstance(output, IntermediateTensors):
                 output = scheduler_output, output
+            elif not get_pp_group().is_last_rank:
+                # Case where there are no scheduled requests
+                # but may still be finished requests.
+                assert not output or not output.req_ids
+                output = scheduler_output, None
             return output
 
         def override_env_vars(self, vars: Dict[str, str]):
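
For context on the new elif branch: downstream pipeline stages expect a (scheduler_output, tensors) pair from every non-last PP rank, so a step that scheduled nothing must still return that pair, just with None in the tensor slot. A hedged, self-contained sketch of the shape normalization with stand-in objects (not the real Ray worker API):

def is_tensors(obj) -> bool:
    # Stands in for isinstance(output, IntermediateTensors).
    return isinstance(obj, dict)

def shape_for_next_stage(scheduler_output, output, is_last_rank: bool):
    if is_tensors(output):                # forward pass ran: pass tensors on
        return scheduler_output, output
    if not is_last_rank:
        # Nothing was scheduled; there may still be finished requests, but
        # there are no tensors to forward, so keep the tuple shape with None.
        assert not output or not output.req_ids
        return scheduler_output, None
    return output                         # last rank returns the runner output

print(shape_for_next_stage({"step": 1}, {"hidden": "..."}, is_last_rank=False))
print(shape_for_next_stage({"step": 1}, None, is_last_rank=False))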

vllm/v1/engine/core.py

Lines changed: 36 additions & 33 deletions
@@ -138,12 +138,12 @@ def __init__(self,
         # schedule and execute batches, and is required by pipeline parallelism
         # to eliminate pipeline bubbles.
         self.batch_queue_size = self.model_executor.max_concurrent_batches
-        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
-                                                     SchedulerOutput]]] = None
+        self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput],
+                                               SchedulerOutput]]] = None
         if self.batch_queue_size > 1:
             logger.info("Batch queue is enabled with size %d",
                         self.batch_queue_size)
-            self.batch_queue = queue.Queue(self.batch_queue_size)
+            self.batch_queue = deque(maxlen=self.batch_queue_size)
 
         self.request_block_hasher: Optional[Callable[[Request],
                                                      list[BlockHash]]] = None
@@ -319,41 +319,43 @@ def step_with_batch_queue(
         batch in the job queue is finished.
         3. Update the scheduler from the output.
         """
-        assert self.batch_queue is not None
+        batch_queue = self.batch_queue
+        assert batch_queue is not None
 
-        engine_core_outputs = None
-        scheduler_output = None
         # Try to schedule a new batch if the batch queue is not full, but
         # the scheduler may return an empty batch if all requests are scheduled.
         # Note that this is not blocking.
-        if not self.batch_queue.full():
-            scheduler_output = self.scheduler.schedule()
-            if scheduler_output.total_num_scheduled_tokens > 0:
-                future = self.model_executor.execute_model(scheduler_output)
-                self.batch_queue.put_nowait(
-                    (future, scheduler_output))  # type: ignore
-
-        scheduled_batch = (scheduler_output is not None
-                           and scheduler_output.total_num_scheduled_tokens > 0)
-
-        # If no more requests can be scheduled and the job queue is not empty,
-        # block until the first batch in the job queue is finished.
-        # TODO(comaniac): Ideally we should peek the first batch in the
-        # job queue to check if it's finished before scheduling a new batch,
-        # but peeking the first element in a queue is not thread-safe,
-        # so we need more work.
-        if not scheduled_batch and not self.batch_queue.empty():
-            future, scheduler_output = self.batch_queue.get_nowait()
+        assert len(batch_queue) < self.batch_queue_size
 
-            # Blocking until the first result is available.
-            model_output = self.execute_model_with_error_logging(
-                lambda _: future.result(), scheduler_output)
+        model_executed = False
+        if self.scheduler.has_requests():
+            scheduler_output = self.scheduler.schedule()
+            future = self.model_executor.execute_model(scheduler_output)
+            batch_queue.appendleft(
+                (future, scheduler_output))  # type: ignore[arg-type]
+
+            model_executed = scheduler_output.total_num_scheduled_tokens > 0
+            if model_executed and len(batch_queue) < self.batch_queue_size \
+                    and not batch_queue[-1][0].done():
+                # Don't block on next worker response unless the queue is full
+                # or there are no more requests to schedule.
+                return None, True
+
+        elif not batch_queue:
+            # Queue is empty. We should not reach here since this method should
+            # only be called when the scheduler contains requests or the queue
+            # is non-empty.
+            return None, False
+
+        # Block until the next result is available.
+        future, scheduler_output = batch_queue.pop()
+        model_output = self.execute_model_with_error_logging(
+            lambda _: future.result(), scheduler_output)
 
-            self.batch_queue.task_done()
-            engine_core_outputs = (self.scheduler.update_from_output(
-                scheduler_output, model_output))
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, model_output)
 
-        return engine_core_outputs, scheduled_batch
+        return engine_core_outputs, model_executed
 
     def shutdown(self):
         self.structured_output_manager.clear_backend()
@@ -388,7 +390,7 @@ def is_sleeping(self) -> bool:
         return self.model_executor.is_sleeping
 
     def execute_dummy_batch(self):
-        self.model_executor.collective_rpc("execute_dummy_batch")
+        self.model_executor.execute_dummy_batch()
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)
@@ -733,7 +735,8 @@ def _process_input_queue(self):
         """Exits when an engine step needs to be performed."""
 
         waited = False
-        while not self.engines_running and not self.scheduler.has_requests():
+        while not self.engines_running and not self.scheduler.has_requests() \
+                and not self.batch_queue:
             if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                 logger.debug("EngineCore waiting for work.")
                 waited = True
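
The reworked step_with_batch_queue amounts to a small pipelining pattern: schedule and submit a batch without blocking, park its future in a deque, and only wait on the oldest in-flight batch when the pipeline is full, its result is already available, or there is nothing new to schedule. A hedged, self-contained sketch of that pattern outside vLLM (fake model step, illustrative names):

from collections import deque
from concurrent.futures import ThreadPoolExecutor
import time

def fake_model_step(batch_id: int) -> str:      # stands in for execute_model
    time.sleep(0.01)
    return f"output-{batch_id}"

def run_pipeline(num_batches: int, queue_size: int = 2) -> list[str]:
    results: list[str] = []
    batch_queue: deque = deque(maxlen=queue_size)
    with ThreadPoolExecutor(max_workers=queue_size) as pool:
        next_batch = 0
        while next_batch < num_batches or batch_queue:
            if next_batch < num_batches:
                # Non-blocking: schedule a new batch and park its future.
                batch_queue.appendleft(
                    (pool.submit(fake_model_step, next_batch), next_batch))
                next_batch += 1
                # Don't block unless the queue is full or the oldest
                # in-flight result is already available.
                if (len(batch_queue) < queue_size
                        and not batch_queue[-1][0].done()):
                    continue
            # Block on the oldest in-flight batch and consume its output.
            future, _batch_id = batch_queue.pop()
            results.append(future.result())
    return results

print(run_pipeline(5))   # ['output-0', 'output-1', 'output-2', 'output-3', 'output-4']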

vllm/v1/executor/abstract.py

Lines changed: 5 additions & 4 deletions
@@ -81,12 +81,10 @@ def register_failure_callback(self, callback: FailureCallback):
         pass
 
     def determine_available_memory(self) -> list[int]:  # in bytes
-        output = self.collective_rpc("determine_available_memory")
-        return output
+        return self.collective_rpc("determine_available_memory")
 
     def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
-        output = self.collective_rpc("get_kv_cache_spec")
-        return output
+        return self.collective_rpc("get_kv_cache_spec")
 
     def execute_model(
         self,
@@ -96,6 +94,9 @@ def execute_model(
                                      args=(scheduler_output, ))
         return output[0]
 
+    def execute_dummy_batch(self) -> None:
+        self.collective_rpc("execute_dummy_batch")
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         output = self.collective_rpc("take_draft_token_ids")
         return output[0]

vllm/v1/executor/multiproc_executor.py

Lines changed: 12 additions & 3 deletions
@@ -191,6 +191,10 @@ def execute_model(
                 outputs, self.output_rank)
         return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
 
+    def execute_dummy_batch(self) -> None:
+        self.collective_rpc("execute_dummy_batch",
+                            unique_reply_rank=self.output_rank)
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         # OPTIMIZATION: Get output only from a single worker (output_rank)
         outputs = self.collective_rpc("take_draft_token_ids",
@@ -242,12 +246,17 @@ def get_response(w: WorkerProcHandle,
             dequeue_timeout = None if deadline is None else (
                 deadline - time.monotonic())
 
-            if non_block:
+            if self.io_thread_pool is not None:
+                # We must consume worker_response_mq from a single thread.
                 result = self.io_thread_pool.submit(  # type: ignore
                     get_response, w, dequeue_timeout, self.shutdown_event)
-            else:
+                if not non_block:
+                    result = result.result()
+            elif not non_block:
                 result = get_response(w, dequeue_timeout)
-
+            else:
+                raise RuntimeError("non_block can only be used when"
+                                   " max_concurrent_batches > 1")
             responses.append(result)
 
         return responses
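
The key constraint behind the get_response dispatch change is that the worker response message queue must only ever be read from a single thread; blocking callers wait on the returned future instead of reading the queue themselves. A hedged sketch of that rule with a plain queue.Queue standing in for worker_response_mq:

import queue
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

response_mq: "queue.Queue[str]" = queue.Queue()     # stands in for worker_response_mq
io_thread_pool = ThreadPoolExecutor(max_workers=1)  # the single consumer thread

def get_response(timeout: Optional[float] = None) -> str:
    # Only the io_thread_pool thread should ever call this.
    return response_mq.get(timeout=timeout)

def collect(non_block: bool):
    if io_thread_pool is not None:
        # All dequeues are funneled through the one I/O thread.
        result = io_thread_pool.submit(get_response, 1.0)
        return result if non_block else result.result()
    if not non_block:
        return get_response(1.0)
    raise RuntimeError("non_block requires the I/O thread pool")

response_mq.put("worker-0 done")
print(collect(non_block=False))   # blocks on the future -> "worker-0 done"
io_thread_pool.shutdown()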

vllm/v1/worker/gpu_worker.py

Lines changed: 22 additions & 21 deletions
@@ -354,36 +354,37 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
+        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
+        if forward_pass and not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
                     all_gather_group=get_tp_group()))
 
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
+        if isinstance(output, ModelRunnerOutput):
+            return output
 
+        assert isinstance(output, IntermediateTensors)
         parallel_config = self.vllm_config.parallel_config
-        if parallel_config.distributed_executor_backend != "external_launcher" \
-                and not get_pp_group().is_last_rank:
-            assert isinstance(output, IntermediateTensors)
-            get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
-
-            kv_connector_output = output.kv_connector_output
-            if not kv_connector_output:
-                return None
-
-            # In case of PP with kv transfer, we need to pass through the
-            # kv_connector_output
-            if (not kv_connector_output.finished_sending
-                    and not kv_connector_output.finished_recving):
-                return EMPTY_MODEL_RUNNER_OUTPUT
-
-            output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-            output.kv_connector_output = kv_connector_output
-            return output
+        assert parallel_config.distributed_executor_backend != (
+            "external_launcher") and not get_pp_group().is_last_rank
+
+        get_pp_group().send_tensor_dict(output.tensors,
+                                        all_gather_group=get_tp_group())
+
+        kv_connector_output = output.kv_connector_output
+        if not kv_connector_output:
+            return None
+
+        # In case of PP with kv transfer, we need to pass through the
+        # kv_connector_output
+        if (not kv_connector_output.finished_sending
+                and not kv_connector_output.finished_recving):
+            return EMPTY_MODEL_RUNNER_OUTPUT
 
-        assert isinstance(output, ModelRunnerOutput)
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.kv_connector_output = kv_connector_output
         return output
 
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
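
The worker-side change gates the receive of intermediate tensors on an actual forward pass and returns any ModelRunnerOutput immediately, so only non-last PP ranks fall through to the send path. A hedged sketch of that control flow with stand-in types (no real distributed groups or tensors):

from dataclasses import dataclass, field
from typing import Optional, Union

@dataclass
class FakeRunnerOutput:           # stands in for ModelRunnerOutput
    req_ids: list = field(default_factory=list)

@dataclass
class FakeIntermediateTensors:    # stands in for IntermediateTensors
    tensors: dict = field(default_factory=dict)

def worker_execute(num_scheduled_tokens: int, is_first_rank: bool,
                   is_last_rank: bool, run_model) -> Optional[FakeRunnerOutput]:
    intermediate = None
    forward_pass = num_scheduled_tokens > 0
    if forward_pass and not is_first_rank:
        intermediate = FakeIntermediateTensors()  # would be recv'd from prev rank

    output: Union[FakeRunnerOutput, FakeIntermediateTensors] = run_model(intermediate)
    if isinstance(output, FakeRunnerOutput):
        return output                             # sampled output: report it

    # Only non-last ranks end up here: forward tensors to the next rank and
    # report nothing to the scheduler.
    assert not is_last_rank
    # ... send output.tensors to the next rank here ...
    return None

# Last rank produces sampled output; earlier ranks hand tensors downstream.
print(worker_execute(8, True, True, lambda _: FakeRunnerOutput(req_ids=["0"])))
print(worker_execute(8, True, False, lambda _: FakeIntermediateTensors()))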
