@@ -370,6 +370,7 @@ def split_multi_medias(_inputs):
370370 positive_encoded = self ._encode_truncated (positive )
371371 for key in positive_encoded :
372372 _encoded [f'positive_{ key } ' ] = positive_encoded [key ]
373+ _encoded [f'negative_{ key } ' ] = []
373374 labels .append (float (inputs .label ) if inputs .label is not None else 1.0 )
374375
375376 rejected_len = len (inputs .rejected_response ) if inputs .rejected_response else 0
@@ -381,7 +382,7 @@ def split_multi_medias(_inputs):
381382 split_multi_medias (negative )
382383 negative_encoded = self ._encode_truncated (negative )
383384 for key in negative_encoded :
384- _encoded [f'negative { i } _ { key } ' ] = negative_encoded [key ]
385+ _encoded [f'negative_ { key } ' ]. append ( negative_encoded [key ])
385386 labels .append (0.0 )
386387
387388 _encoded ['labels' ] = labels
@@ -1314,10 +1315,18 @@ def _embedding_data_collator(self,
13141315 new_batch = []
13151316 for b in batch :
13161317 keys = [key for key in b .keys () if 'negative' in key ]
1317- max_neg = max ([int (re .findall (r'negative(-?\d+)' , key )[0 ]) for key in keys ]) if keys else None
1318+ max_neg = None
1319+ for key in keys :
1320+ value_list = b [key ]
1321+ suffix = key [len ('negative_' ):]
1322+ max_neg = len (value_list )
1323+ for i , value in enumerate (value_list ):
1324+ b [f'negative{ i } _{ suffix } ' ] = value
1325+ b .pop (key )
1326+
13181327 indexes = ['anchor_' , 'positive_' ]
13191328 if max_neg is not None :
1320- for i in range (0 , max_neg + 1 ):
1329+ for i in range (0 , max_neg ):
13211330 indexes .append (f'negative{ i } _' )
13221331 for prefix in indexes :
13231332 new_batch += self ._fetch_inputs_startswith ([b ], prefix )
0 commit comments