@@ -134,6 +134,7 @@ class TemplateType:
134134 paligemma = 'paligemma'
135135 mplug_owl2 = 'mplug-owl2'
136136 mplug_owl3 = 'mplug_owl3'
137+ mplug_owl3v = 'mplug_owl3v'
137138 wizardlm2_awq = 'wizardlm2-awq'
138139 wizardlm2 = 'wizardlm2'
139140 atom = 'atom'
@@ -4004,7 +4005,69 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
40044005 return res
40054006
40064007
4008+ class mPlugOwl3vTemplate (mPlugOwl3Template ):
4009+ system = None
4010+
4011+ def _encode (self , example : Dict [str , Any ]) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
4012+ inputs , _ = super (mPlugOwl3Template , self )._encode (example )
4013+ if len (inputs ) == 0 :
4014+ return inputs , {}
4015+ images = example ['images' ]
4016+ videos = example ['videos' ]
4017+ cut_enable = not videos
4018+ input_ids = inputs ['input_ids' ]
4019+ labels = inputs ['labels' ]
4020+ idx_list = _findall (input_ids , - 100 )
4021+ processor = self .tokenizer .processor
4022+ inputs = {'_data' : {}}
4023+ if images :
4024+ image_inputs = processor .image_processor (images , cut_enable = cut_enable , return_tensors = 'pt' )
4025+ added_tokens_len = 0
4026+ cut_shapes = image_inputs ['cut_shape' ] or [None ] * 2 * len (idx_list )
4027+ image_token_list = self .tokenizer .encode ('<|image|>' , add_special_tokens = False )
4028+ for idx , cut_shape in zip (idx_list , cut_shapes [::2 ]):
4029+ if cut_shape :
4030+ token_list = self ._get_image_token_list (cut_shape )
4031+ else :
4032+ token_list = image_token_list
4033+ input_ids = input_ids [:idx + added_tokens_len ] + token_list + input_ids [added_tokens_len + idx + 1 :]
4034+ if labels :
4035+ labels = labels [:idx + added_tokens_len ] + [- 100 ] * len (token_list ) + labels [added_tokens_len + idx
4036+ + 1 :]
4037+ added_tokens_len += len (token_list ) - 1
4038+ image_token_idx = torch .tensor (_findall (input_ids , image_token_list ))
4039+
4040+ inputs ['_data' ].update ({
4041+ 'pixel_values' : image_inputs ['pixel_values' ],
4042+ 'media_offset' : image_token_idx ,
4043+ })
4044+ inputs ['_data' ]['input_ids' ] = input_ids
4045+ inputs ['labels' ] = labels
4046+ return inputs , {}
4047+
4048+ def _post_encode (self , model , data : Any ) -> Dict [str , Any ]:
4049+ if 'pixel_values' in data :
4050+ pixel_values = data .pop ('pixel_values' )
4051+ data ['image_embeds' ] = model .forward_image (pixel_values )
4052+ return data
4053+
4054+ def data_collator (self , batch : List [Dict [str , Any ]], padding_to : Optional [int ] = None ) -> Dict [str , Any ]:
4055+ res = super (mPlugOwl3Template , self ).data_collator (batch , padding_to )
4056+ image_embeds = [b ['image_embeds' ] for b in batch if 'image_embeds' in b ]
4057+ if image_embeds :
4058+ res ['image_embeds' ] = torch .concat (image_embeds )
4059+ media_offset = []
4060+
4061+ for bi , b in enumerate (batch ):
4062+ media_offset .append (b .get ('media_offset' , torch .tensor ([]).long ()))
4063+
4064+ if media_offset :
4065+ res ['media_offset' ] = media_offset
4066+ return res
4067+
4068+
40074069register_template (TemplateType .mplug_owl3 , mPlugOwl3Template (), use_model = True , lazy_tokenize = True )
4070+ register_template (TemplateType .mplug_owl3v , mPlugOwl3vTemplate (), use_model = True , lazy_tokenize = True )
40084071
40094072register_template (TemplateType .wizardlm2_awq ,
40104073 Template (['{{SYSTEM}}' ], ['User:\n {{QUERY}}\n \n Assistant:\n ' ], ['\n \n ' ], ['</s>' ]))
0 commit comments