@@ -405,6 +405,10 @@ def _reranker_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
405
405
406
406
positive = deepcopy (inputs )
407
407
positive .rejected_response = []
408
+ if '{doc}' in positive .messages [- 2 ]['content' ]:
409
+ positive .messages [- 2 ]['content' ] = positive .messages [- 2 ]['content' ].replace (
410
+ '{doc}' , inputs .messages [- 1 ]['content' ])
411
+ positive .messages .pop (- 1 )
408
412
positive_encoded = self ._encode_truncated (positive )
409
413
for key in positive_encoded :
410
414
_encoded [f'positive_{ key } ' ] = positive_encoded [key ]
@@ -414,7 +418,12 @@ def _reranker_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
414
418
rejected_len = len (inputs .rejected_response ) if inputs .rejected_response else 0
415
419
for i in range (rejected_len ):
416
420
negative = deepcopy (inputs )
417
- negative .messages [- 1 ]['content' ] = negative .rejected_response [i ]
421
+ if '{doc}' in negative .messages [- 2 ]['content' ]:
422
+ negative .messages [- 2 ]['content' ] = negative .messages [- 2 ]['content' ].replace (
423
+ '{doc}' , negative .rejected_response [i ])
424
+ negative .messages .pop (- 1 )
425
+ else :
426
+ negative .messages [- 1 ]['content' ] = negative .rejected_response [i ]
418
427
negative .rejected_response = []
419
428
negative_encoded = self ._encode_truncated (negative )
420
429
for key in negative_encoded :
@@ -1637,19 +1646,62 @@ def _torchacc_xtuner_data_collator(self, res, padding_to, tokenizer, padding_sid
1637
1646
def print_inputs (self , inputs : Dict [str , Any ], tokenizer_kwargs : Optional [Dict [str , Any ]] = None ) -> None :
1638
1647
if tokenizer_kwargs is None :
1639
1648
tokenizer_kwargs = {}
1640
- for key in [
1641
- 'input' , 'labels' , 'generate' , 'chosen_input' , 'chosen_labels' , 'rejected_input' , 'rejected_labels'
1642
- ]:
1649
+
1650
+ # Base keys to check
1651
+ base_keys = [
1652
+ 'input' , 'labels' , 'generate' , 'chosen_input' , 'chosen_labels' , 'rejected_input' , 'rejected_labels'
1653
+ ]
1654
+
1655
+ # For reranker/embedding modes, also check prefixed keys
1656
+ if self .mode in {'reranker' , 'generative_reranker' , 'embedding' }:
1657
+ prefixes = []
1658
+ if self .mode in {'reranker' , 'generative_reranker' }:
1659
+ prefixes = ['positive_' , 'negative_' ]
1660
+ elif self .mode == 'embedding' :
1661
+ prefixes = ['anchor_' , 'positive_' , 'negative_' ]
1662
+
1663
+ # Add prefixed keys for reranker/embedding modes
1664
+ extended_keys = base_keys .copy ()
1665
+ for prefix in prefixes :
1666
+ for base_key in ['input' , 'labels' ]:
1667
+ extended_keys .append (f'{ prefix } { base_key } ' )
1668
+
1669
+ # Also check for numbered negative keys (negative0_, negative1_, etc.)
1670
+ input_keys = list (inputs .keys ())
1671
+ for key in input_keys :
1672
+ if any (key .startswith (f'{ prefix } ' ) for prefix in prefixes ):
1673
+ # Extract the base key after removing prefix
1674
+ for prefix in prefixes :
1675
+ if key .startswith (prefix ):
1676
+ base_key = key [len (prefix ):]
1677
+ if base_key in ['input_ids' , 'labels'
1678
+ ] or base_key .rstrip ('0123456789_' ) in ['input' , 'labels' ]:
1679
+ extended_keys .append (key .replace ('_ids' , '' ))
1680
+ break
1681
+
1682
+ keys_to_check = list (set (extended_keys ))
1683
+ else :
1684
+ keys_to_check = base_keys
1685
+
1686
+ for key in keys_to_check :
1687
+ # Skip labels completely for certain modes
1688
+ if key .endswith ('labels' ) and self .mode in {'reranker' , 'generative_reranker' }:
1689
+ continue
1690
+
1643
1691
val = inputs .get (key ) # fix val is a tensor
1644
1692
if val is None :
1645
1693
val = inputs .get (f'{ key } _ids' )
1646
1694
if val is not None :
1647
1695
key_upper = key .upper ()
1648
1696
logger .info (f'[{ key_upper } _IDS] { val } ' )
1649
- if key == 'labels' and self .mode in {'seq_cls' , 'embedding' , 'reranker' , 'generative_reranker ' }:
1697
+ if key . endswith ( 'labels' ) and self .mode in {'seq_cls' , 'embedding' }:
1650
1698
continue
1651
1699
if isinstance (val , (list , tuple , torch .Tensor )):
1652
- val_str = self .safe_decode (val , ** tokenizer_kwargs )
1700
+ # Handle nested lists (e.g., for reranker negative samples)
1701
+ if isinstance (val , (list , tuple )) and len (val ) > 0 and isinstance (val [0 ], (list , tuple )):
1702
+ val_str = [self .safe_decode (sub_val , ** tokenizer_kwargs ) for sub_val in val ]
1703
+ else :
1704
+ val_str = self .safe_decode (val , ** tokenizer_kwargs )
1653
1705
logger .info (f'[{ key_upper } ] { val_str } ' )
1654
1706
if inputs .get ('loss_scale' ) is not None :
1655
1707
val = inputs ['loss_scale' ]
0 commit comments