Skip to content

Commit 052c36d

Browse files
authored
[TRTLLM-9522][feat] support image_embeds in OpenAI API (#9715)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent 487287a commit 052c36d

File tree

11 files changed

+367
-84
lines changed

11 files changed

+367
-84
lines changed

cpp/kernels/fmha_v2/pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ markers =
66
fmhca
77
debug
88
bench
9+
needs_l40s
910
# bin: unit tests
1011
# test: python script for invoking fmha.exe
1112
testpaths = bin test

docs/source/commands/trtllm-serve/trtllm-serve.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,24 @@ TRT-LLM multimodal supports the following modalities and data types (depending o
170170
`load_base64_image utility <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/utils/load_base64_image.py>`__
171171
for implementation details.
172172
173+
**Image embeddings**
174+
175+
It is also possible to directly provide the image embeddings to be used by the multimodal
176+
model.
177+
178+
* Using "image_embeds" with base64-encoded data:
179+
180+
.. code-block:: json
181+
182+
{"role": "user", "content": [
183+
{"type": "text", "text": "What's in this image?"},
184+
{"type": "image_embeds", "image_embeds": {"data": "{image_embeddings_base64}"}}
185+
]}
186+
187+
.. note::
188+
The contents of `image_embeddings_base64` can be generated by base64-encoding
189+
the result of serializing a tensor via `torch.save`.
190+
173191
**Video**
174192
175193
* Using "video_url":

tensorrt_llm/inputs/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
async_load_audio, async_load_image, async_load_video,
1717
convert_image_mode, default_multimodal_input_loader,
1818
encode_base64_content_from_url, encode_base64_image,
19-
get_cache_salt_id, load_image, load_video)
19+
get_cache_salt_id, load_base64_image_embeds, load_image,
20+
load_video)
2021

2122
__all__ = [
2223
"ALL_SUPPORTED_MULTIMODAL_MODELS",
@@ -57,4 +58,5 @@
5758
"get_cache_salt_id",
5859
"compute_retained_tokens_count",
5960
"compute_retention_mask",
61+
"load_base64_image_embeds",
6062
]

tensorrt_llm/inputs/utils.py

Lines changed: 101 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,15 @@ def load_base64_image(parsed_url: str) -> Image.Image:
114114
return image
115115

116116

117+
def load_base64_image_embeds(str_content: str) -> torch.Tensor:
    """Deserialize base64-encoded image embeddings into a CPU tensor.

    Args:
        str_content: Base64 text whose decoded bytes are a tensor
            serialized with ``torch.save``.

    Returns:
        The deserialized ``torch.Tensor``, mapped to CPU.
    """
    raw_bytes = base64.b64decode(str_content)
    buffer = BytesIO(raw_bytes)
    try:
        # weights_only=True restricts unpickling to tensor/primitive
        # payloads, guarding against arbitrary-code deserialization.
        embeds: torch.Tensor = torch.load(buffer,
                                          weights_only=True,
                                          map_location="cpu")
    finally:
        buffer.close()
    return embeds
124+
125+
117126
def load_image(image: Union[str, Image.Image],
118127
format: str = "pt",
119128
device: str = "cpu") -> Union[Image.Image, torch.Tensor]:
@@ -425,13 +434,14 @@ class MultimodalData(TypedDict):
425434
"""Type definition for multimodal data structure."""
426435
modality: str
427436
data: Any
437+
is_embedding: bool
428438

429439

430440
class ConversationMessage(TypedDict):
431441
"""Type definition for conversation message structure."""
432442
role: str
433443
content: List[dict[str, Any]]
434-
media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]]
444+
media: List[MultimodalData]
435445

436446
# @classmethod
437447
# def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage":
@@ -446,33 +456,57 @@ def __init__(
446456
model_type: str,
447457
multimodal_server_config: Optional[MultimodalServerConfig] = None):
448458
self._model_type = model_type
449-
self._data = defaultdict[str](list)
450-
self._placeholder_counts = defaultdict[str](int)
459+
self._data = defaultdict[str, list](list)
460+
self._embeddings = defaultdict[str, list](list)
461+
self._placeholder_counts = defaultdict[str, int](int)
451462
self._multimodal_server_config = multimodal_server_config if multimodal_server_config is not None else MultimodalServerConfig(
452463
)
453464

454-
async def retrieve_all_async(self) -> Optional[Dict[str, List[Any]]]:
455-
"""Retrieve all collected multimodal data."""
456-
if not self._data:
457-
return None
458-
459-
return {
460-
modality: await asyncio.gather(*items)
461-
for modality, items in self._data.items()
462-
}
463-
464-
def retrieve_all_sync(self) -> Optional[Dict[str, List[Any]]]:
465-
"""Retrieve all collected multimodal data."""
466-
if not self._data:
467-
return None
468-
469-
return {modality: items for modality, items in self._data.items()}
470-
471-
def add_data(self, media_type: str, data: Union[Coroutine, Any]):
472-
current_count = len(self._data[media_type]) + 1
465+
async def retrieve_all_async(
    self
) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
    """Await and return all collected multimodal data and embeddings.

    Returns:
        A ``(data, embeddings)`` pair. Each element is either ``None``
        (nothing collected at all) or a mapping from modality name to the
        list of awaited items; modalities with empty item lists are dropped.
    """

    async def _gather_by_modality(
            collected: Optional[dict[str, list]]
    ) -> Optional[Dict[str, List[Any]]]:
        if not collected:
            return None
        resolved: Dict[str, List[Any]] = {}
        for modality, pending in collected.items():
            if pending:
                resolved[modality] = await asyncio.gather(*pending)
        return resolved

    return (await _gather_by_modality(self._data),
            await _gather_by_modality(self._embeddings))
481+
482+
def retrieve_all_sync(
    self
) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
    """Return all collected multimodal data and embeddings without awaiting.

    Returns:
        A ``(data, embeddings)`` pair. Each element is either ``None``
        (nothing collected at all) or a mapping from modality name to its
        item list; modalities with empty item lists are dropped.
    """

    def _nonempty(
            collected: Optional[dict[str, list]]
    ) -> Optional[Dict[str, List[Any]]]:
        if not collected:
            return None
        kept: Dict[str, List[Any]] = {}
        for modality, items in collected.items():
            if items:
                kept[modality] = items
        return kept

    return _nonempty(self._data), _nonempty(self._embeddings)
498+
499+
def add_data(self,
             media_type: str,
             data: Union[Coroutine, Any],
             *,
             is_embedding: bool = False):
    """Record one multimodal item (raw data or a precomputed embedding).

    Args:
        media_type: Modality name, e.g. ``"image"``.
        data: The item to track — may be an awaitable or an already
            loaded value.
        is_embedding: If ``True``, the item is stored in the embeddings
            bucket instead of the raw-data bucket.
    """
    # 1-based position of this item among everything (data + embeddings)
    # already collected for the modality; used to pick its placeholder.
    position = 1 + len(self._data[media_type]) + len(
        self._embeddings[media_type])
    placeholder = retrieve_multimodal_placeholder(self._model_type,
                                                  media_type, position)
    if is_embedding:
        self._embeddings[media_type].append(data)
    else:
        self._data[media_type].append(data)
    if placeholder:
        self._placeholder_counts[placeholder] += 1
478512

@@ -643,42 +677,46 @@ def convert_to_conversation_message(
643677
media = [media]
644678
if modality in ["image", "multiple_image"]:
645679
if is_embedding:
680+
_load = lambda mm: mm
681+
646682
# each mm_embedding corresponds to each image placeholder
647683
if not isinstance(media, list):
648684
media = [media]
649-
650-
mm_data = [{
651-
'modality': modality,
652-
'mm_embedding_info': mm
653-
} for mm in media]
654685
else:
655-
mm_data = [
656-
MultimodalData(modality=modality,
657-
data=load_image(i,
658-
format=image_data_format,
659-
device=device))
660-
for i in media
661-
]
686+
_load = lambda mm: load_image(
687+
mm, format=image_data_format, device=device)
688+
689+
mm_data = [
690+
MultimodalData(modality=modality,
691+
data=_load(mm),
692+
is_embedding=is_embedding) for mm in media
693+
]
662694
elif modality == "video":
663695
if is_embedding:
664696
raise ValueError(
665697
"External embedding is not supported for video modality yet."
666698
)
667699
mm_data = [
668-
MultimodalData(modality=modality,
669-
data=load_video(i,
670-
num_frames,
671-
format=image_data_format,
672-
device=device)) for i in media
700+
MultimodalData(
701+
modality=modality,
702+
data=load_video(i,
703+
num_frames,
704+
format=image_data_format,
705+
device=device),
706+
is_embedding=False,
707+
) for i in media
673708
]
674709
elif modality == "audio":
675710
if is_embedding:
676711
raise ValueError(
677712
"External embedding is not supported for audio modality yet."
678713
)
679714
mm_data = [
680-
MultimodalData(modality=modality,
681-
data=load_audio(i, device=device)) for i in media
715+
MultimodalData(
716+
modality=modality,
717+
data=load_audio(i, device=device),
718+
is_embedding=False,
719+
) for i in media
682720
]
683721
elif modality == "image_audio":
684722
if is_embedding:
@@ -706,16 +744,22 @@ def convert_to_conversation_message(
706744
pass
707745
if _modal is None:
708746
raise ValueError(f"Unknown matching modality: {modality}")
709-
mm_data.append(MultimodalData(modality=_modal, data=data))
747+
mm_data.append(
748+
MultimodalData(modality=_modal,
749+
data=data,
750+
is_embedding=False))
710751
elif modality == "mixture_text_image":
711752
mm_data = []
712753
for m in media:
713754
if m:
714755
mm_data.append(
715-
MultimodalData(modality="image",
716-
data=load_image(m,
717-
format=image_data_format,
718-
device=device)))
756+
MultimodalData(
757+
modality="image",
758+
data=load_image(m,
759+
format=image_data_format,
760+
device=device),
761+
is_embedding=False,
762+
))
719763
else:
720764
raise ValueError(f"Unknown modality: {modality}")
721765
return ConversationMessage(role="user", content=prompt, media=mm_data)
@@ -749,17 +793,12 @@ def convert_to_conversation_message(
749793
is_embedding)
750794
mm_data_tracker = MultimodalDataTracker(model_type)
751795
for mdata in conv["media"]:
752-
# Check if mdata is a MultimodalData
753-
if isinstance(mdata,
754-
dict) and "modality" in mdata and "data" in mdata:
755-
mdata_modality = mdata["modality"]
756-
if modality == "multiple_image":
757-
mdata_modality = "image"
758-
mm_data_tracker.add_data(mdata_modality, mdata["data"])
759-
else:
760-
# Add embeddings to the tracker for placeholder handling
761-
mm_data_tracker.add_data(mdata["modality"],
762-
mdata["mm_embedding_info"])
796+
mdata_modality = mdata["modality"]
797+
if modality == "multiple_image":
798+
mdata_modality = "image"
799+
mm_data_tracker.add_data(mdata_modality,
800+
mdata["data"],
801+
is_embedding=is_embedding)
763802
mm_placeholder_counts = mm_data_tracker.placeholder_counts()
764803
prompt = conv["content"]
765804
if mm_placeholder_counts:
@@ -776,11 +815,13 @@ def convert_to_conversation_message(
776815

777816
if mm_placeholder_counts:
778817
if mm_embeddings is not None:
779-
input[
818+
_, input[
780819
"multi_modal_embeddings"] = mm_data_tracker.retrieve_all_sync(
781820
)
782821
else:
783-
input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync()
822+
input[
823+
"multi_modal_data"], _ = mm_data_tracker.retrieve_all_sync(
824+
)
784825
inputs.append(input)
785826

786827
return inputs

0 commit comments

Comments
 (0)