
Commit c3752dc

Update the EngineCore interface for backward compatibility

Signed-off-by: Jialin Ouyang <[email protected]>

1 parent 201ad34 commit c3752dc

2 files changed: +18 -15 lines

vllm/v1/engine/core.py

Lines changed: 14 additions & 12 deletions

@@ -195,7 +195,9 @@ def _initialize_kv_caches(
                 "warmup model) took %.2f seconds"), elapsed)
         return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
 
-    def add_request(self, request: Request):
+    def add_request(self, request: Union[EngineCoreRequest, Request]):
+        if type(request) is EngineCoreRequest:
+            request = self._preprocess_add_request(request)
         """Add request to the scheduler."""
         if pooling_params := request.pooling_params:
             supported_pooling_tasks = (
@@ -204,13 +206,13 @@ def add_request(self, request: Request):
             raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                              f"Supported tasks: {supported_pooling_tasks}")
 
-        if request.mm_hashes is not None:
+        if request.mm_hashes:
            # Here, if hash exists for a multimodal input, then it will be
            # fetched from the cache, else it will be added to the cache.
            # Note that the cache here is mirrored with the client cache, so
            # anything that has a hash must have a HIT cache entry here
            # as well.
-            assert request.mm_inputs is not None
+            assert request.mm_inputs
             updated_mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                 request.mm_inputs, request.mm_hashes)
             assert isinstance(updated_mm_inputs, list)
@@ -389,6 +391,13 @@ def save_tensorized_model(
         self.model_executor.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
+    def _preprocess_add_request(self, request: EngineCoreRequest) -> Request:
+        """Preprocess the request.
+
+        This function could be directly used in input processing thread to allow
+        request initialization running in parallel with Model forward"""
+        return Request.from_engine_core_request(request)
+
 
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -772,7 +781,7 @@ def process_input_sockets(self, input_addresses: list[str],
                    # Deserialize the request data.
                    if request_type == EngineCoreRequestType.ADD:
                        request = add_request_decoder.decode(data_frames)
-                        request = self._post_process_add_request(request)
+                        request = self._preprocess_add_request(request)
                    else:
                        request = generic_decoder.decode(data_frames)
 
@@ -840,13 +849,6 @@ def process_output_sockets(self, output_paths: list[str],
                 # Limit the number of buffers to reuse.
                 reuse_buffers.append(buffer)
 
-    def _post_process_add_request(self, request: EngineCoreRequest) -> Request:
-        """Post-processes the request before reaching to EngineCore.
-
-        This call would be executed in parallel with Model forward which
-        relaxes request preparation works out from critical path."""
-        return Request.from_engine_core_request(request)
-
 
 class DPEngineCoreProc(EngineCoreProc):
     """ZMQ-wrapper for running EngineCore in background process
@@ -927,7 +929,7 @@ def shutdown(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)
 
-    def add_request(self, request: Request):
+    def add_request(self, request: Union[EngineCoreRequest, Request]):
         if self.has_coordinator and request.current_wave != self.current_wave:
             if request.current_wave > self.current_wave:
                 self.current_wave = request.current_wave
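
Taken together, the core.py hunks widen EngineCore.add_request so it accepts either the wire-level EngineCoreRequest or an already-converted Request, converting on the spot only when needed. Below is a minimal, self-contained sketch of that dispatch pattern; both classes are simplified stand-ins for illustration, not the real vLLM types.

    # Sketch of the backward-compatible dispatch this commit introduces.
    # EngineCoreRequest / Request here are simplified stand-ins, not the
    # actual vLLM classes.
    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class EngineCoreRequest:      # wire-level form, as serialized by the client
        request_id: str
        prompt_token_ids: list[int]

    @dataclass
    class Request:                # scheduler-ready form
        request_id: str
        prompt_token_ids: list[int]

        @classmethod
        def from_engine_core_request(cls, r: EngineCoreRequest) -> "Request":
            # In vLLM this conversion does real work; here it just copies fields.
            return cls(r.request_id, r.prompt_token_ids)

    class EngineCore:
        def _preprocess_add_request(self, request: EngineCoreRequest) -> Request:
            return Request.from_engine_core_request(request)

        def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
            # Backward compatible: old callers still pass the wire-level type,
            # which is converted here; callers that already converted (e.g. an
            # input-processing thread) pass a Request straight through.
            if type(request) is EngineCoreRequest:
                request = self._preprocess_add_request(request)
            print(f"scheduling {request.request_id}")

    engine = EngineCore()
    engine.add_request(EngineCoreRequest("a", [1, 2, 3]))      # old call style
    engine.add_request(Request.from_engine_core_request(       # new call style
        EngineCoreRequest("b", [4, 5])))

Note the exact `type(request) is EngineCoreRequest` check mirrors the diff: it keeps the already-converted fast path cheap and avoids isinstance subclass surprises.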

vllm/v1/engine/core_client.py

Lines changed: 4 additions & 3 deletions

@@ -32,6 +32,7 @@
 from vllm.v1.engine.utils import (CoreEngineActorManager,
                                   CoreEngineProcManager, launch_core_engines)
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.request import Request
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
 
 logger = init_logger(__name__)
@@ -104,7 +105,7 @@ def shutdown(self):
     def get_output(self) -> EngineCoreOutputs:
         raise NotImplementedError
 
-    def add_request(self, request: EngineCoreRequest) -> None:
+    def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
         raise NotImplementedError
 
     def profile(self, is_start: bool = True) -> None:
@@ -238,7 +239,7 @@ def get_output(self) -> EngineCoreOutputs:
         outputs, _ = self.engine_core.step()
         return outputs.get(0) or EngineCoreOutputs()
 
-    def add_request(self, request: EngineCoreRequest) -> None:
+    def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
         self.engine_core.add_request(request)
 
     def abort_requests(self, request_ids: list[str]) -> None:
@@ -603,7 +604,7 @@ def call_utility(self, method: str, *args) -> Any:
 
         return future.result()
 
-    def add_request(self, request: EngineCoreRequest) -> None:
+    def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
         if self.is_dp:
             self.engines_running = True
         self._send_input(EngineCoreRequestType.ADD, request)
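
The docstring of the renamed _preprocess_add_request states the motivation: the input-socket thread can run the EngineCoreRequest-to-Request conversion while the engine's main loop is busy with the model forward pass, keeping request initialization off the critical path. The toy script below illustrates that overlap with a plain queue and thread; it assumes nothing about vLLM's actual socket or decoder machinery, and the sleeps are stand-ins for real work.

    # Toy illustration (assumed structure, not vLLM's code) of overlapping
    # request preprocessing with the model forward pass.
    import queue
    import threading
    import time

    work_q: "queue.Queue[str]" = queue.Queue()

    def input_thread(raw_requests: list[str]) -> None:
        for raw in raw_requests:
            time.sleep(0.01)                 # stand-in for decode + preprocess
            work_q.put(f"Request({raw})")    # engine loop receives a ready Request

    threading.Thread(target=input_thread,
                     args=(["a", "b", "c"],), daemon=True).start()

    for _ in range(3):
        time.sleep(0.02)                     # stand-in for EngineCore.step() forward
        print("scheduled:", work_q.get())    # preprocessing ran concurrently

Because the conversion helper now lives on EngineCore itself (rather than only on EngineCoreProc), both the in-process client and the ZMQ-wrapped background process share one code path, which is what allows the widened Union signatures above.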
