@@ -195,7 +195,9 @@ def _initialize_kv_caches(
                      "warmup model) took %.2f seconds"), elapsed)
         return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
 
-    def add_request(self, request: Request):
+    def add_request(self, request: Union[EngineCoreRequest, Request]):
         """Add request to the scheduler."""
+        if type(request) is EngineCoreRequest:
+            request = self._preprocess_add_request(request)
         if pooling_params := request.pooling_params:
             supported_pooling_tasks = (
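The hunk above widens the public entry point: callers may now pass either a raw EngineCoreRequest or an already-built Request, and the exact `type(...)` check routes only the former through conversion. A minimal sketch of this normalize-at-entry pattern, using stand-in classes rather than the real vLLM types:

    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class WireRequest:  # stand-in for EngineCoreRequest
        request_id: str

    @dataclass
    class SchedRequest:  # stand-in for Request
        request_id: str

        @classmethod
        def from_wire(cls, wire: "WireRequest") -> "SchedRequest":
            return cls(request_id=wire.request_id)

    def add_request(request: Union[WireRequest, SchedRequest]) -> None:
        # Exact type check, as in the diff: only the wire format is converted.
        if type(request) is WireRequest:
            request = SchedRequest.from_wire(request)
        print(f"scheduling {request.request_id}")

    add_request(WireRequest("r1"))   # converted on entry
    add_request(SchedRequest("r2"))  # already preprocessed upstream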
@@ -204,13 +206,13 @@ def add_request(self, request: Request):
                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                  f"Supported tasks: {supported_pooling_tasks}")
 
-        if request.mm_hashes is not None:
+        if request.mm_hashes:
             # Here, if hash exists for a multimodal input, then it will be
             # fetched from the cache, else it will be added to the cache.
             # Note that the cache here is mirrored with the client cache, so
             # anything that has a hash must have a HIT cache entry here
             # as well.
-            assert request.mm_inputs is not None
+            assert request.mm_inputs
             updated_mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                 request.mm_inputs, request.mm_hashes)
             assert isinstance(updated_mm_inputs, list)
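Note the semantic change in this hunk: the old guard required only non-None, while the truthiness test also skips an empty list of hashes (and the relaxed assert likewise rejects an empty mm_inputs). Plain Python, nothing vLLM-specific:

    for mm_hashes in (None, [], ["hash-a"]):
        old = mm_hashes is not None  # True for [] too
        new = bool(mm_hashes)        # False for both None and []
        print(repr(mm_hashes), old, new)
    # None False False
    # [] True False      <- only the new guard skips an empty list
    # ['hash-a'] True True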
@@ -389,6 +391,13 @@ def save_tensorized_model(
         self.model_executor.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
+    def _preprocess_add_request(self, request: EngineCoreRequest) -> Request:
+        """Preprocess the request.
+
+        This function can be called directly from the input processing thread
+        so that request initialization runs in parallel with the model forward pass."""
+        return Request.from_engine_core_request(request)
+
 
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -772,7 +781,7 @@ def process_input_sockets(self, input_addresses: list[str],
                 # Deserialize the request data.
                 if request_type == EngineCoreRequestType.ADD:
                     request = add_request_decoder.decode(data_frames)
-                    request = self._post_process_add_request(request)
+                    request = self._preprocess_add_request(request)
                 else:
                     request = generic_decoder.decode(data_frames)
 
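This is the call site that realizes the parallelism: `process_input_sockets` runs on a dedicated input thread, so the request is fully built before the engine's busy loop ever sees it. The general shape, reduced to a self-contained sketch (`expensive_init` is a stand-in for decoding plus Request.from_engine_core_request):

    import queue
    import threading

    ready = queue.Queue()

    def expensive_init(raw: bytes) -> str:
        # Stand-in for deserialization + Request construction.
        return raw.decode().upper()

    def input_thread(frames: list[bytes]) -> None:
        # Build requests off the critical path and enqueue ready objects.
        for raw in frames:
            ready.put(expensive_init(raw))

    t = threading.Thread(target=input_thread, args=([b"req-1", b"req-2"],))
    t.start()

    # The core loop (here, the main thread) pops fully built requests,
    # analogous to EngineCoreProc pulling from its input queue.
    for _ in range(2):
        print("scheduling", ready.get())
    t.join()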
@@ -840,13 +849,6 @@ def process_output_sockets(self, output_paths: list[str],
                     # Limit the number of buffers to reuse.
                     reuse_buffers.append(buffer)
 
-    def _post_process_add_request(self, request: EngineCoreRequest) -> Request:
-        """Post-processes the request before reaching to EngineCore.
-
-        This call would be executed in parallel with Model forward which
-        relaxes request preparation works out from critical path."""
-        return Request.from_engine_core_request(request)
-
 
 class DPEngineCoreProc(EngineCoreProc):
     """ZMQ-wrapper for running EngineCore in background process
@@ -927,7 +929,7 @@ def shutdown(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)
 
-    def add_request(self, request: Request):
+    def add_request(self, request: Union[EngineCoreRequest, Request]):
         if self.has_coordinator and request.current_wave != self.current_wave:
             if request.current_wave > self.current_wave:
                 self.current_wave = request.current_wave
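The DP override only widens the accepted type; its wave bookkeeping is untouched. The bump rule is small enough to state standalone (WaveTracker is a hypothetical stand-in, not the actual DPEngineCoreProc):

    class WaveTracker:
        def __init__(self) -> None:
            self.current_wave = 0

        def observe(self, request_wave: int) -> None:
            # A request from a newer wave advances the engine's wave;
            # requests from older waves fall through to the existing
            # coordinator handling.
            if request_wave != self.current_wave:
                if request_wave > self.current_wave:
                    self.current_wave = request_wave

    tracker = WaveTracker()
    tracker.observe(2)
    print(tracker.current_wave)  # 2
    tracker.observe(1)
    print(tracker.current_wave)  # still 2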
0 commit comments