@@ -1387,6 +1387,76 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
13871387
13881388class Internvl2Template (InternvlTemplate ):
13891389
# Number of frames used per video: `replace_tag` emits this many
# "Frame<N>" placeholders and `encode` passes it to `load_video` as
# `num_segments`.
video_segments = 8
1391+
def replace_tag(self, media_type, index, example) -> List[Context]:
    """Return the context pieces that stand in for one media tag.

    An image becomes a single ``[-100]`` placeholder. A video becomes
    ``self.video_segments`` consecutive triples of a frame caption, a
    ``[-100]`` placeholder and a line-break string. Any other media type
    yields ``None``.

    ``index`` and ``example`` are accepted for interface compatibility and
    are not used here.
    """
    if media_type == 'image':
        return [[-100]]
    if media_type != 'video':
        # Unknown media types fall through without a value.
        return None

    pieces = []
    for frame_no in range(1, self.video_segments + 1):
        # A fresh [-100] list per frame so the placeholders are never shared.
        pieces.extend((f'Frame{frame_no} : ', [-100], '\n '))
    return pieces
1402+
def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Tokenize ``example`` and expand its media placeholders.

    The text is encoded by the ``encode`` that follows ``InternvlTemplate``
    in the MRO. Every ``-100`` entry found in ``input_ids`` is then replaced
    by the tokenized ``<img>…<IMG_CONTEXT>…</img>`` span, with
    ``self.num_image_token`` context tokens per patch. The matching label
    positions are filled with ``-100``.

    Args:
        example: Sample dict. ``example['images']`` may be a path or a list
            of paths; ``example['videos']`` may be a single path or a
            list/tuple holding exactly one path.

    Returns:
        ``(inputs, {})``. When media is present, ``inputs`` additionally
        carries ``pixel_values`` (cast to ``self.model.dtype``) and
        ``image_flags`` (a tensor of ones, one entry per patch).
        ``loss_scale`` is removed from ``inputs`` if present.

    Raises:
        AssertionError: If the number of images / video frames does not match
            the number of placeholders, or more than one video is supplied.
    """
    # Explicitly bypass InternvlTemplate.encode: start the MRO lookup after it.
    inputs, _ = super(InternvlTemplate, self).encode(example)
    if not inputs:
        return inputs, {}

    input_ids = inputs['input_ids']
    labels = inputs.get('labels')
    # Located once, on the unexpanded sequence; the helper below compensates
    # for the growth caused by earlier insertions.
    idx_list = _findall(input_ids, -100)

    def _expand_placeholders(ids, lbls, patch_counts):
        """Swap each placeholder in ``ids`` for its image-token span."""
        shift = 0
        for idx, num_patch in zip(idx_list, patch_counts):
            img_tokens: List[int] = self.tokenizer.encode(
                '<img>' + '<IMG_CONTEXT>' * self.num_image_token * num_patch + '</img>\n ',
                add_special_tokens=False)
            pos = idx + shift
            ids = ids[:pos] + img_tokens + ids[pos + 1:]
            if lbls is not None:
                # Image tokens are never a prediction target.
                lbls = lbls[:pos] + [-100] * len(img_tokens) + lbls[pos + 1:]
            # One placeholder was consumed, len(img_tokens) tokens were added.
            shift += len(img_tokens) - 1
        return ids, lbls

    images_path = example.get('images') or []
    videos_path = example.get('videos') or []

    if images_path:
        from .vision_utils import load_image
        if isinstance(images_path, str):
            images_path = [images_path]
        pixel_values = [load_image(image_path) for image_path in images_path]

        assert len(images_path) == len(idx_list), (
            f'got {len(images_path)} image(s) for {len(idx_list)} placeholder(s)')
        # The first dimension of each loaded image is used as its patch count.
        patch_counts = [pv.shape[0] for pv in pixel_values]
        input_ids, labels = _expand_placeholders(input_ids, labels, patch_counts)
        inputs['input_ids'] = input_ids
        inputs['labels'] = labels
        inputs['pixel_values'] = torch.cat(pixel_values).to(self.model.dtype)
        inputs['image_flags'] = torch.ones(sum(patch_counts))

    if videos_path:
        if not isinstance(videos_path, (list, tuple)):
            videos_path = [videos_path]
        assert len(videos_path) == 1, (
            f'exactly one video is supported, got {len(videos_path)}')
        from swift.llm.utils.vision_utils import load_video
        pixel_values, num_patches = load_video(videos_path[0], num_segments=self.video_segments)
        assert len(num_patches) == len(idx_list), (
            f'got {len(num_patches)} video segment(s) for {len(idx_list)} placeholder(s)')
        input_ids, labels = _expand_placeholders(input_ids, labels, num_patches)
        inputs['input_ids'] = input_ids
        inputs['labels'] = labels
        inputs['pixel_values'] = pixel_values.to(self.model.dtype)
        inputs['image_flags'] = torch.ones(sum(num_patches))

    inputs.pop('loss_scale', None)
    return inputs, {}
1459+
13901460 def __init__ (self ):
13911461 self .system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
13921462 Template .__init__ (self , [], ['<|im_start|>user\n {{QUERY}}<|im_end|><|im_start|>assistant\n ' ], ['<|im_end|>' ],
0 commit comments