@@ -870,33 +870,80 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
 
     def _compute_acc(self, outputs, labels) -> None:
         args = self.args
-        preds = outputs.logits.argmax(dim=-1)
-        if self.template.sequence_parallel_size > 1:
-            from swift.trainers.sequence_parallel import sequence_parallel
-            # Gather preds and labels across the sp group
-            if isinstance(preds, np.ndarray):
-                preds = torch.from_numpy(preds).to(get_current_device())
-            if isinstance(labels, np.ndarray):
-                labels = torch.from_numpy(labels).to(get_current_device())
-            assert labels.shape[1] == preds.shape[1]
-
-            if sequence_parallel.rp_world_size > 1:
-                position_ids = sequence_parallel.real_position_ids
-                position_ids = sequence_parallel.pad(position_ids, padding_value=-1, position_ids=position_ids)
-            else:
-                position_ids = None
-            preds_output = sequence_parallel.gather(preds, dim=1, position_ids=position_ids)
-            labels_output = sequence_parallel.gather(labels, dim=1, position_ids=position_ids)
-            # roll back to fit compute_acc
-            labels_output = torch.roll(labels_output, shifts=1, dims=1)
-            preds = preds_output
-            labels = labels_output.int()
-
-        metrics = compute_acc(
-            preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)
-        mode = 'train' if self.model.training else 'eval'
-        for k, v in metrics.items():
-            self.custom_metrics[mode][k].update(v)
+        logits = outputs.logits
+        metrics = None
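+        # Generative reranker: judge relevance by comparing two answer-token logits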
+        if getattr(args, 'loss_type', None) in {'generative_reranker', 'listwise_generative_reranker'} \
+                and logits is not None and logits.dim() == 3:
+            tokenizer = getattr(self, 'processing_class', None)
+            if tokenizer is None and getattr(self, 'template', None) is not None:
+                tokenizer = self.template.tokenizer
+            if tokenizer is None:
+                raise RuntimeError('tokenizer not available for generative_reranker acc')
+
+            positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
+            negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
+
+            try:
+                positive_token_id = tokenizer.convert_tokens_to_ids(positive_token)
+                negative_token_id = tokenizer.convert_tokens_to_ids(negative_token)
+            except Exception as e:
+                logger.warning(f'Failed to convert reranker tokens to ids: {e}')
+                positive_token_id = None
+                negative_token_id = None
+
+            if isinstance(positive_token_id, int) and isinstance(negative_token_id, int) \
+                    and positive_token_id >= 0 and negative_token_id >= 0:
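+                # Compare the positive and negative answer-token logits at the last position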
+                positive_logits = logits[:, -1, positive_token_id]
+                negative_logits = logits[:, -1, negative_token_id]
+                binary_preds = (positive_logits > negative_logits).long()
+                metrics = compute_acc(
+                    binary_preds,
+                    labels.long(),
+                    acc_strategy=args.acc_strategy,
+                    is_encoder_decoder=self.template.is_encoder_decoder)
+        elif logits.dim() == 1 or (logits.dim() == 2 and logits.size(-1) == 1):
+            if logits.dim() == 2:
+                logits = logits.squeeze(-1)
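+            # Pointwise reranker head: logit > 0 is equivalent to sigmoid(logit) > 0.5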
+            binary_preds = (logits > 0).long()
+            metrics = compute_acc(
+                binary_preds,
+                labels.long(),
+                acc_strategy=args.acc_strategy,
+                is_encoder_decoder=self.template.is_encoder_decoder)
+        else:
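+            # Default causal-LM path: token-level argmax accuracy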
+            preds = logits.argmax(dim=-1)
+            if self.template.sequence_parallel_size > 1:
+                from swift.trainers.sequence_parallel import sequence_parallel
+                # Gather preds and labels across the sp group
+                if isinstance(preds, np.ndarray):
+                    preds = torch.from_numpy(preds).to(get_current_device())
+                if isinstance(labels, np.ndarray):
+                    labels = torch.from_numpy(labels).to(get_current_device())
+                assert labels.shape[1] == preds.shape[1]
+
+                if sequence_parallel.rp_world_size > 1:
+                    position_ids = sequence_parallel.real_position_ids
+                    position_ids = sequence_parallel.pad(position_ids, padding_value=-1, position_ids=position_ids)
+                else:
+                    position_ids = None
+                preds_output = sequence_parallel.gather(preds, dim=1, position_ids=position_ids)
+                labels_output = sequence_parallel.gather(labels, dim=1, position_ids=position_ids)
+                # roll back to fit compute_acc
+                labels_output = torch.roll(labels_output, shifts=1, dims=1)
+                preds = preds_output
+                labels = labels_output.int()
+
+            metrics = compute_acc(
+                preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)
+
+        if metrics:
+            mode = 'train' if self.model.training else 'eval'
+            for k, v in metrics.items():
+                self.custom_metrics[mode][k].update(v)
 
     @torch.no_grad()
     def _evalscope_eval(self):
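
Note on the new branch: for `generative_reranker` / `listwise_generative_reranker`, accuracy is derived by comparing the logits of the configured positive and negative answer tokens at the final sequence position. A minimal standalone sketch of that decision rule, assuming a Hugging Face-style tokenizer and `[batch, seq_len, vocab]` logits (the function name and defaults are illustrative, not part of this diff):

```python
import torch

def reranker_binary_preds(logits: torch.Tensor, tokenizer,
                          positive_token: str = 'yes',
                          negative_token: str = 'no') -> torch.Tensor:
    """Predict 1 ("relevant") when the positive answer token outscores the negative one."""
    pos_id = tokenizer.convert_tokens_to_ids(positive_token)
    neg_id = tokenizer.convert_tokens_to_ids(negative_token)
    # logits: [batch, seq_len, vocab]; the last position scores the answer token.
    return (logits[:, -1, pos_id] > logits[:, -1, neg_id]).long()
```

Since only two vocabulary entries are compared, this is equivalent to taking the argmax of a binary softmax over the {positive, negative} token pair.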
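On the sequence-parallel path, the gathered labels are rolled one position to the right before `compute_acc` ("roll back to fit compute_acc"), presumably undoing the left shift applied when labels were prepared for next-token loss. A toy illustration of that realignment (values made up):

```python
import torch

# Labels as prepared for next-token loss: position t holds the target for position t + 1.
shifted = torch.tensor([[2, 3, 4, -100]])
restored = torch.roll(shifted, shifts=1, dims=1)  # wraps the last element to the front
print(restored)  # tensor([[-100,    2,    3,    4]])
```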