Merged
1 change: 1 addition & 0 deletions cpp/kernels/fmha_v2/pytest.ini
@@ -6,6 +6,7 @@ markers =
fmhca
debug
bench
needs_l40s
# bin: unit tests
# test: python script for invoking fmha.exe
testpaths = bin test
18 changes: 18 additions & 0 deletions docs/source/commands/trtllm-serve/trtllm-serve.rst
@@ -170,6 +170,24 @@ TRT-LLM multimodal supports the following modalities and data types (depending o
`load_base64_image utility <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/utils/load_base64_image.py>`__
for implementation details.

**Image embeddings**

It is also possible to provide precomputed image embeddings directly to the multimodal
model.

* Using "image_embeds" with base64-encoded data:

.. code-block:: json

{"role": "user", "content": [
{"type": "text", "text": "What's in this image?"},
{"type": "image_embeds", "image_embeds": {"data": "{image_embeddings_base64}"}}}
]}

.. note::
The contents of `image_embeddings_base64` can be generated by base64-encoding
the result of serializing a tensor via `torch.save`.
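
The following is a minimal client-side sketch of producing that base64 string and placing it
in a chat message. The tensor name, shape, and dtype are illustrative placeholders, not values
prescribed by the API; the server simply base64-decodes the string and recovers the tensor with
`torch.load`.

.. code-block:: python

    import base64
    import io

    import torch

    # Precomputed image embeddings; shape and dtype depend on the model's
    # vision encoder (random data is used here purely as a placeholder).
    image_embeds = torch.randn(576, 1024)

    # Serialize the tensor with torch.save, then base64-encode the bytes.
    buf = io.BytesIO()
    torch.save(image_embeds, buf)
    image_embeddings_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_embeds",
             "image_embeds": {"data": image_embeddings_base64}},
        ],
    }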

**Video**

* Using "video_url":
4 changes: 3 additions & 1 deletion tensorrt_llm/inputs/__init__.py
@@ -16,7 +16,8 @@
async_load_audio, async_load_image, async_load_video,
convert_image_mode, default_multimodal_input_loader,
encode_base64_content_from_url, encode_base64_image,
get_cache_salt_id, load_image, load_video)
get_cache_salt_id, load_base64_image_embeds, load_image,
load_video)

__all__ = [
"ALL_SUPPORTED_MULTIMODAL_MODELS",
@@ -57,4 +58,5 @@
"get_cache_salt_id",
"compute_retained_tokens_count",
"compute_retention_mask",
"load_base64_image_embeds",
]
161 changes: 101 additions & 60 deletions tensorrt_llm/inputs/utils.py
@@ -114,6 +114,15 @@ def load_base64_image(parsed_url: str) -> Image.Image:
return image


def load_base64_image_embeds(str_content: str) -> torch.Tensor:
content_bytes = base64.b64decode(str_content)
with BytesIO(content_bytes) as buf:
image_data: torch.Tensor = torch.load(buf,
weights_only=True,
map_location="cpu")
return image_data


def load_image(image: Union[str, Image.Image],
format: str = "pt",
device: str = "cpu") -> Union[Image.Image, torch.Tensor]:
@@ -425,13 +434,14 @@ class MultimodalData(TypedDict):
"""Type definition for multimodal data structure."""
modality: str
data: Any
is_embedding: bool


class ConversationMessage(TypedDict):
"""Type definition for conversation message structure."""
role: str
content: List[dict[str, Any]]
media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]]
media: List[MultimodalData]

# @classmethod
# def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage":
@@ -446,33 +456,57 @@ def __init__(
model_type: str,
multimodal_server_config: Optional[MultimodalServerConfig] = None):
self._model_type = model_type
self._data = defaultdict[str](list)
self._placeholder_counts = defaultdict[str](int)
self._data = defaultdict[str, list](list)
self._embeddings = defaultdict[str, list](list)
self._placeholder_counts = defaultdict[str, int](int)
self._multimodal_server_config = multimodal_server_config if multimodal_server_config is not None else MultimodalServerConfig(
)

async def retrieve_all_async(self) -> Optional[Dict[str, List[Any]]]:
"""Retrieve all collected multimodal data."""
if not self._data:
return None

return {
modality: await asyncio.gather(*items)
for modality, items in self._data.items()
}

def retrieve_all_sync(self) -> Optional[Dict[str, List[Any]]]:
"""Retrieve all collected multimodal data."""
if not self._data:
return None

return {modality: items for modality, items in self._data.items()}

def add_data(self, media_type: str, data: Union[Coroutine, Any]):
current_count = len(self._data[media_type]) + 1
async def retrieve_all_async(
self
) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
"""Retrieve all collected multimodal data and embeddings."""

async def _retrieve(
data: Optional[dict[str,
list]]) -> Optional[Dict[str, List[Any]]]:
if not data:
return None
return {
modality: await asyncio.gather(*items)
for modality, items in data.items() if items
}

return await _retrieve(self._data), await _retrieve(self._embeddings)

def retrieve_all_sync(
self
) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
"""Retrieve all collected multimodal data and embeddings."""

def _retrieve(
data: Optional[dict[str,
list]]) -> Optional[Dict[str, List[Any]]]:
if not data:
return None
return {
modality: items
for modality, items in data.items() if items
}

return _retrieve(self._data), _retrieve(self._embeddings)

def add_data(self,
media_type: str,
data: Union[Coroutine, Any],
*,
is_embedding: bool = False):
current_count = len(self._data[media_type]) + len(
self._embeddings[media_type]) + 1
placeholder = retrieve_multimodal_placeholder(self._model_type,
media_type, current_count)
self._data[media_type].append(data)
(self._embeddings
if is_embedding else self._data)[media_type].append(data)
if placeholder:
self._placeholder_counts[placeholder] += 1

@@ -643,42 +677,46 @@ def convert_to_conversation_message(
media = [media]
if modality in ["image", "multiple_image"]:
if is_embedding:
_load = lambda mm: mm

# each mm_embedding corresponds to each image placeholder
if not isinstance(media, list):
media = [media]

mm_data = [{
'modality': modality,
'mm_embedding_info': mm
} for mm in media]
else:
mm_data = [
MultimodalData(modality=modality,
data=load_image(i,
format=image_data_format,
device=device))
for i in media
]
_load = lambda mm: load_image(
mm, format=image_data_format, device=device)

mm_data = [
MultimodalData(modality=modality,
data=_load(mm),
is_embedding=is_embedding) for mm in media
]
elif modality == "video":
if is_embedding:
raise ValueError(
"External embedding is not supported for video modality yet."
)
mm_data = [
MultimodalData(modality=modality,
data=load_video(i,
num_frames,
format=image_data_format,
device=device)) for i in media
MultimodalData(
modality=modality,
data=load_video(i,
num_frames,
format=image_data_format,
device=device),
is_embedding=False,
) for i in media
]
elif modality == "audio":
if is_embedding:
raise ValueError(
"External embedding is not supported for audio modality yet."
)
mm_data = [
MultimodalData(modality=modality,
data=load_audio(i, device=device)) for i in media
MultimodalData(
modality=modality,
data=load_audio(i, device=device),
is_embedding=False,
) for i in media
]
elif modality == "image_audio":
if is_embedding:
@@ -706,16 +744,22 @@ def convert_to_conversation_message(
pass
if _modal is None:
raise ValueError(f"Unknown matching modality: {modality}")
mm_data.append(MultimodalData(modality=_modal, data=data))
mm_data.append(
MultimodalData(modality=_modal,
data=data,
is_embedding=False))
elif modality == "mixture_text_image":
mm_data = []
for m in media:
if m:
mm_data.append(
MultimodalData(modality="image",
data=load_image(m,
format=image_data_format,
device=device)))
MultimodalData(
modality="image",
data=load_image(m,
format=image_data_format,
device=device),
is_embedding=False,
))
else:
raise ValueError(f"Unknown modality: {modality}")
return ConversationMessage(role="user", content=prompt, media=mm_data)
@@ -749,17 +793,12 @@ def convert_to_conversation_message(
is_embedding)
mm_data_tracker = MultimodalDataTracker(model_type)
for mdata in conv["media"]:
# Check if mdata is a MultimodalData
if isinstance(mdata,
dict) and "modality" in mdata and "data" in mdata:
mdata_modality = mdata["modality"]
if modality == "multiple_image":
mdata_modality = "image"
mm_data_tracker.add_data(mdata_modality, mdata["data"])
else:
# Add embeddings to the tracker for placeholder handling
mm_data_tracker.add_data(mdata["modality"],
mdata["mm_embedding_info"])
mdata_modality = mdata["modality"]
if modality == "multiple_image":
mdata_modality = "image"
mm_data_tracker.add_data(mdata_modality,
mdata["data"],
is_embedding=is_embedding)
mm_placeholder_counts = mm_data_tracker.placeholder_counts()
prompt = conv["content"]
if mm_placeholder_counts:
@@ -776,11 +815,13 @@

if mm_placeholder_counts:
if mm_embeddings is not None:
input[
_, input[
"multi_modal_embeddings"] = mm_data_tracker.retrieve_all_sync(
)
else:
input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync()
input[
"multi_modal_data"], _ = mm_data_tracker.retrieve_all_sync(
)
inputs.append(input)

return inputs