Skip to content

Commit a325fa1

Browse files
committed
demollm support for multi-modal input
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 6eff0da commit a325fa1

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

tensorrt_llm/_torch/auto_deploy/shim/demollm.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""A demo LLM API for debugging and testing purposes of e2e workflows."""
22

33
import gc
4+
from collections import defaultdict
45
from queue import Empty
5-
from typing import Any, Callable, List, Optional, Tuple
6+
from typing import Any, Callable, Dict, List, Optional, Tuple
67

78
import torch
89
import torch.multiprocessing as mp
@@ -93,10 +94,22 @@ def generate_tokens_batched(
9394
# set up sequence info object for decode phase
9495
sequence_info = self.cache_seq_interface.info
9596
sequence_info.reset()
96-
total_lens = [len(r.prompt_token_ids) for r in requests]
97+
98+
input_ids = []
99+
total_lens = []
100+
extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)
101+
102+
for request in requests:
103+
total_lens.append(len(request.prompt_token_ids))
104+
input_ids.append(request.prompt_token_ids)
105+
if request.multimodal_params is not None:
106+
for k, v in request.multimodal_params.multimodal_data.items():
107+
extra_args[k].append(v)
108+
97109
sequence_info.nest_sequences(
98-
input_ids=[r.prompt_token_ids for r in requests],
110+
input_ids=input_ids,
99111
page_assignments=self._assign_pages(total_lens),
112+
**extra_args,
100113
)
101114

102115
# setup objects we want to track for the output

0 commit comments

Comments
 (0)