@@ -30,7 +30,6 @@ class KeyeVLTemplate(Template):
3030 def replace_tag (self , media_type : Literal ['image' , 'video' , 'audio' ], index : int ,
3131 inputs : StdTemplateInputs ) -> List [Context ]:
3232 from keye_vl_utils import fetch_image , fetch_video
33- # from qwen_vl_utils import fetch_image, fetch_video
3433 assert media_type in {'image' , 'video' }
3534 if media_type == 'image' :
3635 inputs .images [index ] = fetch_image ({'image' : inputs .images [index ]})
@@ -49,7 +48,6 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
4948 return ['<|vision_start|><|video_pad|><|vision_end|>' ]
5049
5150 def _encode (self , inputs : StdTemplateInputs ) -> Dict [str , Any ]:
52- from keye_vl_utils import vision_process
5351 encoded = super ()._encode (inputs )
5452 processor = self .processor
5553 input_ids = encoded ['input_ids' ]
@@ -63,15 +61,16 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
6361 if locals ()[media_type ]:
6462 if media_type == 'images' :
6563 media_token = self .image_token_id
66- media_inputs = processor .image_processor (
67- images = images , videos = None , return_tensors = 'pt' , do_resize = False )
64+ media_inputs = processor .image_processor (images = images , return_tensors = 'pt' , do_resize = False )
6865 media_grid_thw = media_inputs ['image_grid_thw' ]
6966 else :
67+ kwargs = {}
7068 if hasattr (processor , 'video_processor' ):
7169 processor_func = processor .video_processor
7270 else :
7371 processor_func = processor .image_processor
74- media_inputs = processor_func (images = None , videos = videos , return_tensors = 'pt' , do_resize = False )
72+ kwargs ['images' ] = None
73+ media_inputs = processor_func (videos = videos , return_tensors = 'pt' , do_resize = False , ** kwargs )
7574 media_grid_thw = media_inputs ['video_grid_thw' ]
7675 media_token = self .video_token_id
7776 media_inputs ['second_per_grid_ts' ] = [
@@ -118,7 +117,7 @@ def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
118117 if is_deepspeed_enabled ():
119118 from PIL import Image
120119 images = [Image .new ('RGB' , (32 , 32 ), (0 , 0 , 0 ))]
121- media_inputs = self .processor .image_processor (images = images , videos = None , return_tensors = 'pt' )
120+ media_inputs = self .processor .image_processor (images = images , return_tensors = 'pt' )
122121 device = input_ids .device
123122 media_inputs = to_device (media_inputs , device )
124123 pixel_values = media_inputs ['pixel_values' ].type (dtype )
0 commit comments