@@ -41,11 +41,12 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
4141 video = inputs .videos [index ]
4242 if os .path .isdir (video ):
4343 video = [os .path .join (video , fname ) for fname in os .listdir (video )]
44- video , video_kwargs = fetch_video ({'video' : video }, return_video_sample_fps = True )
44+ video , video_kwargs = fetch_video ({'video' : video })
4545 if isinstance (video , torch .Tensor ):
4646 video = video .to (torch .uint8 )
4747 inputs .videos [index ] = video
48- inputs .mm_processor_kwargs .setdefault ('fps' , []).append (video_kwargs )
48+ for k , v in video_kwargs .items ():
49+ inputs .mm_processor_kwargs .setdefault (k , []).append (v )
4950 return ['<|vision_start|><|video_pad|><|vision_end|>' ]
5051
5152 def _encode (self , inputs : StdTemplateInputs ) -> Dict [str , Any ]:
@@ -62,25 +63,24 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
6263 media_inputs = processor .image_processor (images = mm_data , return_tensors = 'pt' , do_resize = False )
6364 media_grid_thw = media_inputs ['image_grid_thw' ]
6465 else :
65- kwargs = {}
66- if hasattr ( processor , 'video_processor' ):
67- processor_func = processor . video_processor
68- else :
69- processor_func = processor . image_processor
70- kwargs [ 'images' ] = None
71- media_inputs = processor_func ( videos = mm_data , return_tensors = 'pt' , do_resize = False , ** kwargs )
66+ split_token = self . _tokenize ( ' \n ' )[ 0 ]
67+ media_inputs = processor (
68+ text = [ ' \n ' . join ([ '<|video_pad|>' ] * len ( mm_data ))],
69+ videos = mm_data ,
70+ return_tensors = 'pt' ,
71+ ** inputs . mm_processor_kwargs )
72+ splited_tokens = self . _split_list ( media_inputs [ 'input_ids' ][ 0 ]. tolist (), split_token )
7273 media_grid_thw = media_inputs ['video_grid_thw' ]
7374 media_token = self .video_token_id
74- fps = inputs .mm_processor_kwargs ['fps' ]
75- media_inputs ['second_per_grid_ts' ] = [
76- processor .image_processor .temporal_patch_size / tmp for tmp in fps
77- ]
7875 idx_list = findall (input_ids , media_token )
7976 merge_length = processor .image_processor .merge_size ** 2
8077
# Replacement-token callback handed to self._extend_tokens below: given the
# index i of a media placeholder, return the token ids to put in its place.
# In this diff the '-' lines are the previous body and the '+' lines the new one.
8178 def _get_new_tokens (i ):
82- token_len = (media_grid_thw [i ].prod () // merge_length )
83- return [media_token ] * token_len
# New body: for images the placeholder still expands to media_token repeated
# prod(grid_thw[i]) // merge_length times.
# NOTE(review): media_type is bound outside the lines visible here - confirm
# that 'images' is the exact value used by the enclosing code.
79+ if media_type == 'images' :
80+ token_len = (media_grid_thw [i ].prod () // merge_length )
81+ return [media_token ] * token_len
# Otherwise return the i-th token segment obtained by splitting the processor's
# input_ids on split_token; splited_tokens is only assigned in the non-image
# branch above.
82+ else :
83+ return splited_tokens [i ]
8484
8585 input_ids , labels , loss_scale = self ._extend_tokens (input_ids , labels , loss_scale , idx_list ,
8686 _get_new_tokens )
@@ -291,6 +291,14 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
291291
292292# Register the Keye VL template
293293register_template (KeyeTemplateMeta (MLLMTemplateType .keye_vl , template_cls = KeyeVLTemplate ))
294+
295+
# Template subclass registered for keye_vl_1_5; it differs from KeyeVLTemplate
# only in how _post_encode is resolved.
296+ class KeyeVL1_5Template (KeyeVLTemplate ):
297+
# The two-argument super(KeyeVLTemplate, self) starts the method lookup *after*
# KeyeVLTemplate in the MRO, so KeyeVLTemplate's own _post_encode is skipped
# and the next class's implementation handles the call.
# NOTE(review): this looks deliberate (not a typo for super()) - confirm.
298+ def _post_encode (self , model , inputs : Dict [str , Any ]) -> Dict [str , Any ]:
299+ return super (KeyeVLTemplate , self )._post_encode (model , inputs )
300+
301+
# Register the keye_vl_1_5 template type. The '-' line is the previous
# registration (template_cls = KeyeVLTemplate); the '+' line switches it to
# KeyeVL1_5Template while keeping the same default system prompt.
294302register_template (
295303 KeyeTemplateMeta (
296- MLLMTemplateType .keye_vl_1_5 , template_cls = KeyeVLTemplate , default_system = 'You are a helpful assistant.' ))
304+ MLLMTemplateType .keye_vl_1_5 , template_cls = KeyeVL1_5Template , default_system = 'You are a helpful assistant.' ))
0 commit comments