Skip to content

Commit e5cbfcc

Browse files
committed
fix demollm for world_size >= 1
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent a325fa1 commit e5cbfcc

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

tensorrt_llm/_torch/auto_deploy/shim/demollm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ....executor import GenerationExecutor
1212
from ....executor.request import GenerationRequest
1313
from ....executor.result import CompletionOutput, GenerationResult
14+
from ....inputs.multimodal import MultimodalParams
1415
from ....sampling_params import SamplingParams
1516
from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch
1617
from ..distributed import common as dist_ad
@@ -35,8 +36,11 @@ def __init__(self, *args, **kwargs) -> None:
3536
self.queue = mp.Queue()
3637

3738
@torch.inference_mode()
38-
def __call__(self, requests: GenerationRequest) -> mp.Queue:
39+
def __call__(
40+
self, requests: GenerationRequest, multimodal_params: Optional[MultimodalParams]
41+
) -> mp.Queue:
3942
"""Generate tokens and put the results in a queue and return the queue."""
43+
requests.multimodal_params = multimodal_params
4044
output = self.generate_tokens_batched([requests])[0]
4145
self.queue.put(output)
4246
return self.queue
@@ -274,6 +278,7 @@ def _run_engine(
274278
def _unpack(inputs) -> GenerationRequest:
275279
args, kwargs = inputs # unpack the inputs
276280
request: GenerationRequest = args[0]
281+
request.multimodal_params: Optional[MultimodalParams] = args[1]
277282
return request
278283

279284
engine = DemoEngine.build_from_config(**engine_kwargs)
@@ -328,8 +333,11 @@ def submit(self, request: GenerationRequest) -> GenerationResult:
328333
request.set_id(client_id)
329334

330335
# submit request to our demo engine and store results
336+
# NOTE: when returning from this function, the reference request.multimodal_params will
337+
# be cleared immediately. So we pass it in explicitly to maintain a reference even when
338+
# requests get submitted asynchronously.
331339
result = GenerationResult(request)
332-
result.queue = self.engine_executor(request)
340+
result.queue = self.engine_executor(request, request.multimodal_params)
333341

334342
return result
335343

0 commit comments

Comments (0)