1717from tensorrt_llm .inputs import (ConversationMessage , MultimodalData ,
1818 MultimodalDataTracker ,
1919 add_multimodal_placeholders , async_load_audio ,
20- async_load_image , async_load_video )
20+ async_load_image , async_load_video ,
21+ load_base64_image_embeds )
2122from tensorrt_llm .inputs .multimodal import MultimodalServerConfig
2223from tensorrt_llm .logger import logger
2324
@@ -33,24 +34,38 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
3334 type : Required [Literal ["video_url" ]]
3435
3536
37+ class ChatCompletionContentPartImageEmbedsParam (TypedDict , total = False ):
38+ """Type definition for image embeddings passed in base64-encoded PyTorch tensor format."""
39+ image_embeds : Required [str ]
40+ type : Required [Literal ["image_embeds" ]]
41+
42+
3643# Type Aliases and Constants
3744ChatCompletionContentPartParam : TypeAlias = Union [
38- OpenAIChatCompletionContentPartParam , ChatCompletionContentPartVideoParam ,
39- str ]
45+ OpenAIChatCompletionContentPartParam ,
46+ ChatCompletionContentPartVideoParam ,
47+ ChatCompletionContentPartImageEmbedsParam ,
48+ str ,
49+ ]
4050
4151# TODO: Add "input_audio" to support byte_encoded audio input.
4252VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
43- "text" , "image_url" , "video_url" , "audio_url"
53+ "text" ,
54+ "image_url" ,
55+ "video_url" ,
56+ "audio_url" ,
57+ "image_embeds" ,
4458]
4559
4660# Parser Functions
4761_TextParser = partial (cast , ChatCompletionContentPartTextParam )
4862_ImageParser = partial (cast , ChatCompletionContentPartImageParam )
63+ _ImageEmbedsParser = partial (cast , ChatCompletionContentPartImageEmbedsParam )
4964_VideoParser = partial (cast , ChatCompletionContentPartVideoParam )
5065_AudioParser = partial (cast , ChatCompletionContentPartInputAudioParam )
5166
5267MM_PARSER_MAP : dict [str , Callable [[ChatCompletionContentPartParam ], Union [
53- str , dict [str , str ]]]] = {
68+ str , dict [str , str ], None ]]] = {
5469 "text" :
5570 lambda part : _TextParser (part ).get ("text" , None ),
5671 "image_url" :
@@ -59,12 +74,20 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
5974 lambda part : _VideoParser (part ).get ("video_url" , {}).get ("url" , None ),
6075 "audio_url" :
6176 lambda part : _AudioParser (part ).get ("audio_url" , {}).get ("url" , None ),
77+ "image_embeds" :
78+ lambda part : _ImageEmbedsParser (part ).get ("image_embeds" , None ),
6279 }
6380
81+ # Map from content part tags used to directly provide embeddings
82+ # to the corresponding data modality.
83+ MM_EMBEDDING_MAP : dict [str , str ] = {
84+ "image_embeds" : "image" ,
85+ }
86+
6487
6588def _parse_chat_message_content_mm_part (
6689 part : ChatCompletionContentPartParam
67- ) -> tuple [str , Union [str , dict [str , str ]]]:
90+ ) -> tuple [str , Union [str , dict [str , str ], None ]]:
6891 """Parse a single multimodal part of a chat message."""
6992 assert isinstance (part , dict )
7093 part_type = part .get ("type" , None )
@@ -78,7 +101,7 @@ def _parse_chat_message_content_mm_part(
78101
79102
80103def parse_chat_message_content_part (
81- part : ChatCompletionMessageParam ,
104+ part : ChatCompletionContentPartParam ,
82105 mm_data_tracker : MultimodalDataTracker ,
83106) -> Optional [Any ]:
84107 """Parse a single part of a chat message."""
@@ -112,6 +135,19 @@ async def load_image_async():
112135
113136 return MultimodalData (modality = "image" , data = load_image_async ())
114137
138+ if part_type == "image_embeds" :
139+ str_content = cast (str , content )
140+
141+ async def decode_image_embeds_async ():
142+ try :
143+ return load_base64_image_embeds (str_content )
144+ except Exception as e :
145+ logger .error (f"Failed to decode image data: { str (e )} " )
146+ return None
147+
148+ return MultimodalData (modality = "image_embeds" ,
149+ data = decode_image_embeds_async ())
150+
115151 if part_type == "video_url" :
116152 str_content = cast (str , content )
117153
@@ -147,7 +183,7 @@ async def load_audio_async():
147183
148184def parse_chat_message_content_parts (
149185 role : str ,
150- parts : Iterable [ChatCompletionMessageParam ],
186+ parts : Iterable [ChatCompletionContentPartParam ],
151187 mm_data_tracker : MultimodalDataTracker ,
152188) -> ConversationMessage :
153189 """Parse multiple parts of a chat message."""
@@ -237,7 +273,10 @@ def parse_chat_messages_coroutines(
237273 conversation .append (parsed_msg )
238274 if parsed_msg ["media" ]:
239275 for mdata in parsed_msg ["media" ]:
240- mm_data_tracker .add_data (mdata ["modality" ], mdata ["data" ])
276+ mm_data_tracker .add_data (mdata ["modality" ],
277+ mdata ["data" ],
278+ modality = MM_EMBEDDING_MAP .get (
279+ mdata ["modality" ], None ))
241280 mm_placeholder_count = mm_data_tracker .placeholder_counts ()
242281 if mm_placeholder_count :
243282 parsed_msg ["content" ] = add_multimodal_placeholders (
0 commit comments