1111from ....executor import GenerationExecutor
1212from ....executor.request import GenerationRequest
1313from ....executor.result import CompletionOutput, GenerationResult
14+ from ....inputs.multimodal import MultimodalParams
1415from ....sampling_params import SamplingParams
1516from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch
1617from ..distributed import common as dist_ad
@@ -35,8 +36,11 @@ def __init__(self, *args, **kwargs) -> None:
3536 self.queue = mp.Queue()
3637
3738 @torch.inference_mode()
38- def __call__(self, requests: GenerationRequest) -> mp.Queue:
39+ def __call__(
40+ self, requests: GenerationRequest, multimodal_params: Optional[MultimodalParams]
41+ ) -> mp.Queue:
3942 """Generate tokens and put the results in a queue and return the queue."""
43+ requests.multimodal_params = multimodal_params
4044 output = self.generate_tokens_batched([requests])[0]
4145 self.queue.put(output)
4246 return self.queue
@@ -274,6 +278,7 @@ def _run_engine(
274278 def _unpack(inputs) -> GenerationRequest:
275279 args, kwargs = inputs  # unpack the inputs
276280 request: GenerationRequest = args[0]
281+ request.multimodal_params: Optional[MultimodalParams] = args[1]
277282 return request
278283
279284 engine = DemoEngine .build_from_config (** engine_kwargs )
@@ -328,8 +333,11 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
328333 request.set_id(client_id)
329334
330335 # submit request to our demo engine and store results
336+ # NOTE: when returning from this function, the reference request.multimodal_params will
337+ # be cleared immediately. So we pass it in explicitly to maintain a reference even when
338+ # requests get submitted asynchronously.
331339 result = GenerationResult(request)
332- result.queue = self.engine_executor(request)
340+ result.queue = self.engine_executor(request, request.multimodal_params)
333341
334342 return result
335343
0 commit comments