@@ -1081,6 +1081,8 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
10811081 image = _read_from_path (image_path )
10821082 images .append (image )
10831083 inputs , _ = super ().encode (example )
1084+ if len (inputs ) == 0 :
1085+ return inputs , {}
10841086 input_ids = inputs ['input_ids' ]
10851087 labels = inputs ['labels' ]
10861088 idx_list = _findall (input_ids , 1 )[1 :] # 1: <s>
@@ -1330,7 +1332,16 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
13301332register_template (TemplateType .minicpm , Template (['<s>{{SYSTEM}}' ], ['<用户>{{QUERY}}<AI>' ], [], ['</s>' ]))
13311333
13321334
1333- class MiniCPMVTemlate (Template ):
1335+ def _remove_idx (arr : List [int ], idx_list : List [int ]) -> List [int ]:
1336+ res = []
1337+ idx_set = set (idx_list )
1338+ for i , x in enumerate (arr ):
1339+ if i not in idx_set :
1340+ res .append (x )
1341+ return res
1342+
1343+
1344+ class MiniCPMVTemplate (Template ):
13341345
13351346 def __init__ (self , * args , ** kwargs ):
13361347 self .is_v2_5 = kwargs .pop ('is_v2_5' , False )
@@ -1345,32 +1356,22 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
13451356 return inputs , {}
13461357 input_ids = inputs ['input_ids' ]
13471358 labels = inputs ['labels' ]
1348-
1349- img_start_idxs = np .where (np .array (input_ids ) == self .tokenizer .im_start_id )[0 ]
1350- if len (img_start_idxs ) > 1 : # if mutli-round, input_ids have mutli <image><unk></image>\n
1351- start = 0
1352- new_input_ids = []
1353- new_labels = []
1354- for idx in img_start_idxs [1 :]:
1355- new_input_ids = new_input_ids + input_ids [start :idx ]
1356- if labels is not None :
1357- new_labels = new_labels + labels [start :idx ]
1358- start = idx + 4 # skip <image><unk></image>\n
1359- new_input_ids = new_input_ids + input_ids [start :]
1360- input_ids = new_input_ids
1359+ idx_list = _findall (input_ids , - 1 )
1360+ if len (idx_list ) >= 2 :
1361+ input_ids = _remove_idx (input_ids , idx_list [1 :])
13611362 if labels is not None :
1362- new_labels = new_labels + labels [start :]
1363- labels = new_labels
1364-
1365- idx = img_start_idxs [0 ] + 1 # first <unk>
1363+ labels = _remove_idx (labels , idx_list [1 :])
1364+ idx = idx_list [0 ]
13661365 config = self .model .config
13671366 tgt_sizes = None
1368- if config .slice_mode :
1367+ slice_mode = getattr (config , 'slice_mode' , False )
1368+ if slice_mode :
13691369 images , placeholder = self .model .get_slice_image_placeholder (image , self .tokenizer )
1370+ placeholder += '\n '
13701371 placeholder_id = self .tokenizer .encode (placeholder , add_special_tokens = False )
1371- input_ids = (input_ids [:idx - 1 ] + placeholder_id + input_ids [idx + 2 :])
1372+ input_ids = (input_ids [:idx ] + placeholder_id + input_ids [idx + 1 :])
13721373 if labels is not None :
1373- labels = (labels [:idx - 1 ] + [- 100 ] * len (placeholder_id ) + labels [idx + 2 :])
1374+ labels = (labels [:idx ] + [- 100 ] * len (placeholder_id ) + labels [idx + 1 :])
13741375 input_tensor_ids = torch .tensor (input_ids )
13751376 image_start_idx = torch .where (input_tensor_ids == self .tokenizer .im_start_id )[0 ]
13761377 image_start_idx += 1
@@ -1393,9 +1394,11 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
13931394 else :
13941395 pixel_values = [self .model .transform (img ).to (device = self .model .device ) for img in images ]
13951396 else :
1396- input_ids = (input_ids [:idx ] + [self .tokenizer .unk_token_id ] * config .query_num + input_ids [idx + 1 :])
1397+ placeholder = '<image>' + '<unk>' * config .query_num + '</image>\n '
1398+ placeholder_id = self .tokenizer .encode (placeholder , add_special_tokens = False )
1399+ input_ids = (input_ids [:idx ] + placeholder_id + input_ids [idx + 1 :])
13971400 if labels is not None :
1398- labels = (labels [:idx ] + [- 100 ] * config . query_num + labels [idx + 1 :])
1401+ labels = (labels [:idx ] + [- 100 ] * len ( placeholder_id ) + labels [idx + 1 :])
13991402 image_bound = [torch .tensor ([[idx , idx + config .query_num ]])]
14001403 pixel_values = [self .model .transform (image ).to (device = self .model .device )]
14011404 data = {
@@ -1418,7 +1421,7 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
14181421
14191422register_template (
14201423 TemplateType .minicpm_v ,
1421- MiniCPMVTemlate (['<s>{{SYSTEM}}' ], ['<用户><image><unk></image> \n {{QUERY}}<AI>' ], [], ['</s>' ]),
1424+ MiniCPMVTemplate (['<s>{{SYSTEM}}' ], ['<用户>' , [ - 1 ], ' {{QUERY}}<AI>' ], [], ['</s>' ]),
14221425 use_model = True ,
14231426 lazy_tokenize = True ,
14241427 infer_media_type = 'dialogue' ,
@@ -1427,11 +1430,11 @@ def get_generate_ids(generate_ids: Tensor, input_token_len: int) -> List[int]:
14271430
14281431register_template (
14291432 TemplateType .minicpm_v_v2_5 ,
1430- MiniCPMVTemlate (['<|begin_of_text|>{{SYSTEM}}' ], [
1431- '<|start_header_id|>user<|end_header_id|>\n \n <image><unk></image> \n {{QUERY}}<|eot_id|>'
1433+ MiniCPMVTemplate (['<|begin_of_text|>{{SYSTEM}}' ], [
1434+ '<|start_header_id|>user<|end_header_id|>\n \n ' , [ - 1 ], ' {{QUERY}}<|eot_id|>'
14321435 '<|start_header_id|>assistant<|end_header_id|>\n \n '
14331436 ], ['<|eot_id|>' ], ['<|eot_id|>' ],
1434- is_v2_5 = True ),
1437+ is_v2_5 = True ),
14351438 use_model = True ,
14361439 lazy_tokenize = True ,
14371440 infer_media_type = 'dialogue' ,
0 commit comments