@@ -114,6 +114,15 @@ def load_base64_image(parsed_url: str) -> Image.Image:
114114 return image
115115
116116
def load_base64_image_embeds(str_content: str) -> torch.Tensor:
    """Decode a base64 string and deserialize it into a CPU tensor.

    Args:
        str_content: Base64-encoded bytes in a format readable by
            ``torch.load`` (presumably produced by ``torch.save`` on the
            sender side -- confirm against the caller).

    Returns:
        The deserialized tensor, mapped to CPU.

    Raises:
        ValueError: If the payload deserializes to something other than a
            ``torch.Tensor``.
    """
    content_bytes = base64.b64decode(str_content)
    with BytesIO(content_bytes) as buf:
        # weights_only=True keeps torch.load on its restricted unpickler;
        # do not relax this, the payload may come from an external request.
        image_data = torch.load(buf, weights_only=True, map_location="cpu")

    # The restricted loader still accepts plain containers and scalars, so
    # enforce the advertised return type here instead of failing downstream.
    if not isinstance(image_data, torch.Tensor):
        raise ValueError(
            "Expected the decoded embedding payload to be a torch.Tensor, "
            f"got {type(image_data).__name__}.")
    return image_data
124+
125+
117126def load_image (image : Union [str , Image .Image ],
118127 format : str = "pt" ,
119128 device : str = "cpu" ) -> Union [Image .Image , torch .Tensor ]:
@@ -425,13 +434,14 @@ class MultimodalData(TypedDict):
425434 """Type definition for multimodal data structure."""
426435 modality : str
427436 data : Any
437+ is_embedding : bool
428438
429439
430440class ConversationMessage (TypedDict ):
431441 """Type definition for conversation message structure."""
432442 role : str
433443 content : List [dict [str , Any ]]
434- media : List [MultimodalData ] | List [ torch . Tensor ] | List [ Dict [ str , Any ]]
444+ media : List [MultimodalData ]
435445
436446 # @classmethod
437447 # def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage":
@@ -446,33 +456,57 @@ def __init__(
446456 model_type : str ,
447457 multimodal_server_config : Optional [MultimodalServerConfig ] = None ):
448458 self ._model_type = model_type
449- self ._data = defaultdict [str ](list )
450- self ._placeholder_counts = defaultdict [str ](int )
459+ self ._data = defaultdict [str , list ](list )
460+ self ._embeddings = defaultdict [str , list ](list )
461+ self ._placeholder_counts = defaultdict [str , int ](int )
451462 self ._multimodal_server_config = multimodal_server_config if multimodal_server_config is not None else MultimodalServerConfig (
452463 )
453464
454- async def retrieve_all_async (self ) -> Optional [Dict [str , List [Any ]]]:
455- """Retrieve all collected multimodal data."""
456- if not self ._data :
457- return None
458-
459- return {
460- modality : await asyncio .gather (* items )
461- for modality , items in self ._data .items ()
462- }
463-
464- def retrieve_all_sync (self ) -> Optional [Dict [str , List [Any ]]]:
465- """Retrieve all collected multimodal data."""
466- if not self ._data :
467- return None
468-
469- return {modality : items for modality , items in self ._data .items ()}
470-
471- def add_data (self , media_type : str , data : Union [Coroutine , Any ]):
472- current_count = len (self ._data [media_type ]) + 1
async def retrieve_all_async(
    self
) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
    """Await and return all collected multimodal data and embeddings.

    Returns:
        A ``(data, embeddings)`` tuple. Each element maps a modality to the
        list of awaited results for that modality, or is ``None`` when the
        corresponding store holds no items at all.
    """

    async def _retrieve(
            data: Optional[dict[str, list]]
    ) -> Optional[Dict[str, List[Any]]]:
        if not data:
            return None
        resolved = {
            modality: await asyncio.gather(*items)
            for modality, items in data.items() if items
        }
        # A store may contain only empty per-modality lists; report that as
        # "nothing collected" (None) rather than as an empty dict.
        return resolved or None

    return await _retrieve(self._data), await _retrieve(self._embeddings)
481+
def retrieve_all_sync(
    self
) -> tuple[Optional[Dict[str, List[Any]]], Optional[Dict[str, List[Any]]]]:
    """Return all collected multimodal data and embeddings without awaiting.

    Returns:
        A ``(data, embeddings)`` tuple. Each element maps a modality to the
        list of items stored for it, or is ``None`` when the corresponding
        store holds no items at all.
    """

    def _retrieve(
            data: Optional[dict[str, list]]
    ) -> Optional[Dict[str, List[Any]]]:
        if not data:
            return None
        non_empty = {
            modality: items
            for modality, items in data.items() if items
        }
        # A store may contain only empty per-modality lists; report that as
        # "nothing collected" (None) rather than as an empty dict.
        return non_empty or None

    return _retrieve(self._data), _retrieve(self._embeddings)
498+
def add_data(self,
             media_type: str,
             data: Union[Coroutine, Any],
             *,
             is_embedding: bool = False):
    """Record one multimodal item and count its placeholder.

    Args:
        media_type: Modality key under which the item is stored.
        data: The item itself, or a coroutine that yields it.
        is_embedding: When True the item is stored in the embeddings store;
            otherwise it is stored in the regular data store.
    """
    # Count with .get() rather than indexing: indexing a defaultdict inserts
    # an empty list, which would leave a spurious empty entry in the store
    # that is NOT being written to.
    current_count = (len(self._data.get(media_type, ())) +
                     len(self._embeddings.get(media_type, ())) + 1)
    placeholder = retrieve_multimodal_placeholder(self._model_type,
                                                  media_type, current_count)
    target = self._embeddings if is_embedding else self._data
    target[media_type].append(data)
    # Only a truthy placeholder is counted.
    if placeholder:
        self._placeholder_counts[placeholder] += 1
478512
@@ -643,42 +677,46 @@ def convert_to_conversation_message(
643677 media = [media ]
644678 if modality in ["image" , "multiple_image" ]:
645679 if is_embedding :
680+ _load = lambda mm : mm
681+
646682 # each mm_embedding corresponds to each image placeholder
647683 if not isinstance (media , list ):
648684 media = [media ]
649-
650- mm_data = [{
651- 'modality' : modality ,
652- 'mm_embedding_info' : mm
653- } for mm in media ]
654685 else :
655- mm_data = [
656- MultimodalData (modality = modality ,
657- data = load_image (i ,
658- format = image_data_format ,
659- device = device ))
660- for i in media
661- ]
686+ _load = lambda mm : load_image (
687+ mm , format = image_data_format , device = device )
688+
689+ mm_data = [
690+ MultimodalData (modality = modality ,
691+ data = _load (mm ),
692+ is_embedding = is_embedding ) for mm in media
693+ ]
662694 elif modality == "video" :
663695 if is_embedding :
664696 raise ValueError (
665697 "External embedding is not supported for video modality yet."
666698 )
667699 mm_data = [
668- MultimodalData (modality = modality ,
669- data = load_video (i ,
670- num_frames ,
671- format = image_data_format ,
672- device = device )) for i in media
700+ MultimodalData (
701+ modality = modality ,
702+ data = load_video (i ,
703+ num_frames ,
704+ format = image_data_format ,
705+ device = device ),
706+ is_embedding = False ,
707+ ) for i in media
673708 ]
674709 elif modality == "audio" :
675710 if is_embedding :
676711 raise ValueError (
677712 "External embedding is not supported for audio modality yet."
678713 )
679714 mm_data = [
680- MultimodalData (modality = modality ,
681- data = load_audio (i , device = device )) for i in media
715+ MultimodalData (
716+ modality = modality ,
717+ data = load_audio (i , device = device ),
718+ is_embedding = False ,
719+ ) for i in media
682720 ]
683721 elif modality == "image_audio" :
684722 if is_embedding :
@@ -706,16 +744,22 @@ def convert_to_conversation_message(
706744 pass
707745 if _modal is None :
708746 raise ValueError (f"Unknown matching modality: { modality } " )
709- mm_data .append (MultimodalData (modality = _modal , data = data ))
747+ mm_data .append (
748+ MultimodalData (modality = _modal ,
749+ data = data ,
750+ is_embedding = False ))
710751 elif modality == "mixture_text_image" :
711752 mm_data = []
712753 for m in media :
713754 if m :
714755 mm_data .append (
715- MultimodalData (modality = "image" ,
716- data = load_image (m ,
717- format = image_data_format ,
718- device = device )))
756+ MultimodalData (
757+ modality = "image" ,
758+ data = load_image (m ,
759+ format = image_data_format ,
760+ device = device ),
761+ is_embedding = False ,
762+ ))
719763 else :
720764 raise ValueError (f"Unknown modality: { modality } " )
721765 return ConversationMessage (role = "user" , content = prompt , media = mm_data )
@@ -749,17 +793,12 @@ def convert_to_conversation_message(
749793 is_embedding )
750794 mm_data_tracker = MultimodalDataTracker (model_type )
751795 for mdata in conv ["media" ]:
752- # Check if mdata is a MultimodalData
753- if isinstance (mdata ,
754- dict ) and "modality" in mdata and "data" in mdata :
755- mdata_modality = mdata ["modality" ]
756- if modality == "multiple_image" :
757- mdata_modality = "image"
758- mm_data_tracker .add_data (mdata_modality , mdata ["data" ])
759- else :
760- # Add embeddings to the tracker for placeholder handling
761- mm_data_tracker .add_data (mdata ["modality" ],
762- mdata ["mm_embedding_info" ])
796+ mdata_modality = mdata ["modality" ]
797+ if modality == "multiple_image" :
798+ mdata_modality = "image"
799+ mm_data_tracker .add_data (mdata_modality ,
800+ mdata ["data" ],
801+ is_embedding = is_embedding )
763802 mm_placeholder_counts = mm_data_tracker .placeholder_counts ()
764803 prompt = conv ["content" ]
765804 if mm_placeholder_counts :
@@ -776,11 +815,13 @@ def convert_to_conversation_message(
776815
777816 if mm_placeholder_counts :
778817 if mm_embeddings is not None :
779- input [
818+ _ , input [
780819 "multi_modal_embeddings" ] = mm_data_tracker .retrieve_all_sync (
781820 )
782821 else :
783- input ["multi_modal_data" ] = mm_data_tracker .retrieve_all_sync ()
822+ input [
823+ "multi_modal_data" ], _ = mm_data_tracker .retrieve_all_sync (
824+ )
784825 inputs .append (input )
785826
786827 return inputs
0 commit comments