
Commit dde2050 (parent: 398d242)

feat: support image_embeds in OpenAI API

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>

File tree

8 files changed: +247 -16 lines changed


docs/source/commands/trtllm-serve/trtllm-serve.rst

Lines changed: 18 additions & 0 deletions

@@ -152,6 +152,24 @@ TRT-LLM multimodal supports the following modalities and data types (depending o
 `load_base64_image utility <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/utils/load_base64_image.py>`__
 for implementation details.
 
+**Image embeddings**
+
+It is also possible to directly provide the image embeddings to use by the multimodal
+model.
+
+* Using "image_embeds" with base64-encoded data:
+
+  .. code-block:: json
+
+     {"role": "user", "content": [
+         {"type": "text", "text": "What's in this image?"},
+         {"type": "image_embeds", "image_embeds": "{image_embeddings_base64}"}
+     ]}
+
+  .. note::
+     The contents of `image_embeddings_base64` can be generated by base64-encoding
+     the result of serializing a tensor via `torch.save`.
+
 **Video**
 
 * Using "video_url":
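
A minimal sketch (not part of this diff) of how a client could produce the `image_embeddings_base64` payload described in the note above; the tensor shape is a placeholder, since the expected embedding layout depends on the target model:

import base64
from io import BytesIO

import torch

# Placeholder embeddings; shape and dtype must match what the model's input processor expects.
image_embeddings = torch.randn(1, 576, 1024)

buffer = BytesIO()
torch.save(image_embeddings, buffer)  # serialize via torch.save, as the docs note describes
image_embeddings_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")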

tensorrt_llm/inputs/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -16,7 +16,8 @@
     async_load_audio, async_load_image, async_load_video,
     convert_image_mode, default_multimodal_input_loader,
     encode_base64_content_from_url, encode_base64_image,
-    get_cache_salt_id, load_image, load_video)
+    get_cache_salt_id, load_base64_image_embeds, load_image,
+    load_video)
 
 __all__ = [
     "ALL_SUPPORTED_MULTIMODAL_MODELS",

@@ -57,4 +58,5 @@
     "get_cache_salt_id",
     "compute_retained_tokens_count",
     "compute_retention_mask",
+    "load_base64_image_embeds",
 ]

tensorrt_llm/inputs/utils.py

Lines changed: 16 additions & 2 deletions

@@ -113,6 +113,15 @@ def load_base64_image(parsed_url: str) -> Image.Image:
     return image
 
 
+def load_base64_image_embeds(str_content: str) -> torch.Tensor:
+    content_bytes = base64.b64decode(str_content)
+    with BytesIO(content_bytes) as buf:
+        image_data: torch.Tensor = torch.load(buf,
+                                              weights_only=True,
+                                              map_location="cpu")
+    return image_data
+
+
 def load_image(image: Union[str, Image.Image],
                format: str = "pt",
                device: str = "cpu") -> Union[Image.Image, torch.Tensor]:

@@ -465,10 +474,15 @@ def retrieve_all_sync(self) -> Optional[Dict[str, List[Any]]]:
 
         return {modality: items for modality, items in self._data.items()}
 
-    def add_data(self, media_type: str, data: Union[Coroutine, Any]):
+    def add_data(self,
+                 media_type: str,
+                 data: Union[Coroutine, Any],
+                 *,
+                 modality: Optional[str] = None):
+        modality = modality or media_type
         current_count = len(self._data[media_type]) + 1
         placeholder = retrieve_multimodal_placeholder(self._model_type,
                                                       modality, current_count)
-                                                      media_type, current_count)
         self._data[media_type].append(data)
         if placeholder:
             self._placeholder_counts[placeholder] += 1
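
A self-contained round-trip sketch for the new helper, assuming a TensorRT-LLM build that includes this commit; it shows that `load_base64_image_embeds` inverts the base64 + `torch.save` encoding and restores the tensor on CPU:

import base64
from io import BytesIO

import torch

from tensorrt_llm.inputs import load_base64_image_embeds

original = torch.arange(8, dtype=torch.float32).reshape(2, 4)

buffer = BytesIO()
torch.save(original, buffer)
payload = base64.b64encode(buffer.getvalue()).decode("utf-8")

# The helper base64-decodes the payload and calls torch.load(weights_only=True, map_location="cpu").
restored = load_base64_image_embeds(payload)
assert torch.equal(restored, original)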

tensorrt_llm/serve/chat_utils.py

Lines changed: 48 additions & 9 deletions

@@ -17,7 +17,8 @@
 from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
                                  MultimodalDataTracker,
                                  add_multimodal_placeholders, async_load_audio,
-                                 async_load_image, async_load_video)
+                                 async_load_image, async_load_video,
+                                 load_base64_image_embeds)
 from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
 from tensorrt_llm.logger import logger
 

@@ -33,24 +34,38 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
     type: Required[Literal["video_url"]]
 
 
+class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
+    """Type definition for image embeddings passed in base64-encoded PyTorch tensor format."""
+    image_embeds: Required[str]
+    type: Required[Literal["image_embeds"]]
+
+
 # Type Aliases and Constants
 ChatCompletionContentPartParam: TypeAlias = Union[
-    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
-    str]
+    OpenAIChatCompletionContentPartParam,
+    ChatCompletionContentPartVideoParam,
+    ChatCompletionContentPartImageEmbedsParam,
+    str,
+]
 
 # TODO: Add "input_audio" to support byte_encoded audio input.
 VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
-    "text", "image_url", "video_url", "audio_url"
+    "text",
+    "image_url",
+    "video_url",
+    "audio_url",
+    "image_embeds",
 ]
 
 # Parser Functions
 _TextParser = partial(cast, ChatCompletionContentPartTextParam)
 _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
+_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
 _VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
 _AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 
 MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
-    str, dict[str, str]]]] = {
+    str, dict[str, str], None]]] = {
     "text":
     lambda part: _TextParser(part).get("text", None),
     "image_url":

@@ -59,12 +74,20 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
     lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
     "audio_url":
     lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
+    "image_embeds":
+    lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
 }
 
+# Map from content part tags used to directly provide embeddings
+# to the corresponding data modality.
+MM_EMBEDDING_MAP: dict[str, str] = {
+    "image_embeds": "image",
+}
+
 
 def _parse_chat_message_content_mm_part(
     part: ChatCompletionContentPartParam
-) -> tuple[str, Union[str, dict[str, str]]]:
+) -> tuple[str, Union[str, dict[str, str], None]]:
     """Parse a single multimodal part of a chat message."""
     assert isinstance(part, dict)
     part_type = part.get("type", None)

@@ -78,7 +101,7 @@ def _parse_chat_message_content_mm_part(
 
 
 def parse_chat_message_content_part(
-    part: ChatCompletionMessageParam,
+    part: ChatCompletionContentPartParam,
     mm_data_tracker: MultimodalDataTracker,
 ) -> Optional[Any]:
     """Parse a single part of a chat message."""

@@ -112,6 +135,19 @@ async def load_image_async():
 
         return MultimodalData(modality="image", data=load_image_async())
 
+    if part_type == "image_embeds":
+        str_content = cast(str, content)
+
+        async def decode_image_embeds_async():
+            try:
+                return load_base64_image_embeds(str_content)
+            except Exception as e:
+                logger.error(f"Failed to decode image data: {str(e)}")
+                return None
+
+        return MultimodalData(modality="image_embeds",
+                              data=decode_image_embeds_async())
+
     if part_type == "video_url":
         str_content = cast(str, content)
 

@@ -147,7 +183,7 @@ async def load_audio_async():
 
 def parse_chat_message_content_parts(
     role: str,
-    parts: Iterable[ChatCompletionMessageParam],
+    parts: Iterable[ChatCompletionContentPartParam],
     mm_data_tracker: MultimodalDataTracker,
 ) -> ConversationMessage:
     """Parse multiple parts of a chat message."""

@@ -237,7 +273,10 @@ def parse_chat_messages_coroutines(
             conversation.append(parsed_msg)
             if parsed_msg["media"]:
                 for mdata in parsed_msg["media"]:
-                    mm_data_tracker.add_data(mdata["modality"], mdata["data"])
+                    mm_data_tracker.add_data(mdata["modality"],
+                                             mdata["data"],
+                                             modality=MM_EMBEDDING_MAP.get(
+                                                 mdata["modality"], None))
             mm_placeholder_count = mm_data_tracker.placeholder_counts()
             if mm_placeholder_count:
                 parsed_msg["content"] = add_multimodal_placeholders(
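
A brief sketch of how an "image_embeds" content part flows through these helpers; the payload string is a stand-in, and the imports assume the `chat_utils` module from this commit:

from tensorrt_llm.serve.chat_utils import MM_EMBEDDING_MAP, MM_PARSER_MAP

part = {"type": "image_embeds", "image_embeds": "<base64-encoded torch.save output>"}

# MM_PARSER_MAP extracts the raw base64 string from the content part.
raw_payload = MM_PARSER_MAP["image_embeds"](part)
assert raw_payload == part["image_embeds"]

# MM_EMBEDDING_MAP maps the "image_embeds" tag to the "image" modality, which
# MultimodalDataTracker.add_data(..., modality=...) then uses to pick the placeholder token.
assert MM_EMBEDDING_MAP.get("image_embeds") == "image"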

tensorrt_llm/serve/openai_server.py

Lines changed: 13 additions & 2 deletions

@@ -34,7 +34,7 @@
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.metrics.collector import MetricsCollector
-from tensorrt_llm.serve.chat_utils import (load_chat_template,
+from tensorrt_llm.serve.chat_utils import (MM_EMBEDDING_MAP, load_chat_template,
                                            parse_chat_messages_coroutines)
 from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
 from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker

@@ -526,7 +526,18 @@ async def create_chat_response(
 
         mm_data = await mm_coroutines
         if mm_data is not None:
-            prompt["multi_modal_data"] = mm_data
+            # single out directly provided embeddings
+            mm_embeds = {}
+            for tag in list(mm_data.keys()):
+                if (modality := MM_EMBEDDING_MAP.get(tag, None)) is not None:
+                    mm_embeds[modality] = mm_data.pop(tag)
+
+            if mm_data:
+                prompt["multi_modal_data"] = mm_data
+            if mm_embeds:
+                prompt["multi_modal_embeddings"] = mm_embeds
+            if mm_data and mm_embeds:
+                raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.")
 
         postproc_args.reasoning_parser = self.llm.args.reasoning_parser
         postproc_args.tool_parser = self.tool_parser
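
An end-to-end request sketch against a local trtllm-serve instance using the standard `openai` client; the base URL, model name, and embedding shape are assumptions for illustration, not part of the diff:

import base64
from io import BytesIO

import torch
from openai import OpenAI

# Encode placeholder embeddings as described in the documentation change above.
buffer = BytesIO()
torch.save(torch.randn(1, 576, 1024), buffer)
image_embeds_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
response = client.chat.completions.create(
    model="Qwen2-VL-7B-Instruct",  # assumed served model name
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_embeds", "image_embeds": image_embeds_b64},
        ],
    }],
)
print(response.choices[0].message.content)
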
(new file)

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# used by tests/unittest/llmapi/apps/_test_openai_chat_multimodal.py
+
+import pickle
+import tempfile
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from tensorrt_llm._torch.models.modeling_qwen2vl import Qwen2VLInputProcessorBase
+from tensorrt_llm.inputs import ExtraProcessedInputs, TextPrompt
+from tensorrt_llm.sampling_params import SamplingParams
+
+
+# signature taken from tensorrt_llm/inputs/registry.py
+def _attach_multimodal_embeddings(
+    self,
+    inputs: TextPrompt,
+    multimodal_embedding: dict[str, list[torch.Tensor]],
+    sampling_params: SamplingParams,
+) -> tuple[list[int], Optional[ExtraProcessedInputs]]:
+    tempdir = tempfile.gettempdir()
+    file_path = Path(tempdir) / "forwarded_embeddings.pickle"
+    with open(file_path, "wb") as f:
+        pickle.dump(multimodal_embedding, f)
+    raise ValueError(file_path)
+
+
+setattr(Qwen2VLInputProcessorBase, "attach_multimodal_embeddings", _attach_multimodal_embeddings)
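
A hypothetical sketch of how a test could consume the pickle written by the patched `attach_multimodal_embeddings`; the inspection logic is illustrative and not taken from the referenced test file:

import pickle
import tempfile
from pathlib import Path

# The patched method dumps the forwarded embeddings to this path and then raises, so the
# request fails fast while the test can still inspect what reached the input processor.
dump_path = Path(tempfile.gettempdir()) / "forwarded_embeddings.pickle"

with open(dump_path, "rb") as f:
    forwarded = pickle.load(f)  # dict[str, list[torch.Tensor]], as passed by the server

print({modality: [tuple(t.shape) for t in tensors]
       for modality, tensors in forwarded.items()})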
