 import re
 import threading
 import json
+import inspect
 import pandas as pd
 from typing import final, Dict, TextIO, Optional, Any, List, Tuple, TYPE_CHECKING, Callable
 from transformers import __version__ as transformers_version
@@ -482,6 +483,7 @@ def run(
         else:
             try:
                 logger.info("Evaluating the running model...")
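+                # Optional flag in training_params; defaults to False so reward scoring stays opt-in.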
+                include_rewards_metrics = training_params.get("include_rewards_metrics", False)
                 model, tokenizer = self._model_service.model, self._model_service.tokenizer
                 if non_default_device_is_available(self._config.DEVICE):
                     model.to(self._config.DEVICE)
@@ -528,8 +530,23 @@ def run(
                 )

                 eval_metrics = trainer.evaluate()
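+                # The trainer reports eval_loss; derive perplexity from it when not already present.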
533+ if "perplexity" not in eval_metrics and "eval_loss" in eval_metrics :
534+ eval_metrics .update ({"perplexity" : math .exp (eval_metrics ["eval_loss" ])})
531535 logger .info (f"Evaluation metrics: { eval_metrics } " )
532536 self ._tracker_client .send_hf_metrics_logs (eval_metrics , 0 )
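+                # Reward-based evaluation is best-effort: failures are logged as warnings and do not fail the run.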
+                if include_rewards_metrics:
+                    try:
+                        reward_metrics = self._evaluate_with_rewards(
+                            model=model,
+                            tokenizer=tokenizer,
+                            eval_dataset=eval_dataset,
+                            max_new_tokens=training_args.max_completion_length,
+                        )
+                        if reward_metrics:
+                            logger.info(f"Reward metrics: {reward_metrics}")
+                            self._tracker_client.send_hf_metrics_logs(reward_metrics, 0)
+                    except Exception as e:
+                        logger.warning(f"Failed to compute reward-based metrics: {e}")
                 self._tracker_client.end_with_success()
                 logger.info("Model evaluation finished")
             except torch.OutOfMemoryError as e:
@@ -577,8 +594,8 @@ def correctness_reward_func(
             answer: List,
             **kwargs: Dict[str, Any]
         ) -> List[float]:
-            responses = [completion[0]['content'] for completion in completions]
-            q = prompts[0][-1]['content']
+            responses = [completion[0]["content"] for completion in completions]
+            q = prompts[0][-1]["content"]
             extracted_responses = [extract_xml_answer(r) for r in responses]
             logger.debug(
                 "%s\nQuestion:\n%s\nAnswer:\n%s\nResponse:\n%s\nExtracted:\n%s",
@@ -591,7 +608,7 @@ def correctness_reward_func(
             return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

         def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
-            responses = [completion[0]['content'] for completion in completions]
+            responses = [completion[0]["content"] for completion in completions]
             extracted_responses = [extract_xml_answer(r) for r in responses]
             return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

@@ -635,6 +652,74 @@ def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> l
             correctness_reward_func,
         ]

+    def _evaluate_with_rewards(
+        self,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizerBase,
+        eval_dataset: datasets.Dataset,
+        max_new_tokens: int,
+    ) -> Dict[str, float]:
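+        """Generate one greedy completion per evaluation prompt and average each reward function's score across the dataset."""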
+        model.eval()
+        if non_default_device_is_available(self._config.DEVICE):
+            model.to(self._config.DEVICE)
+
+        reward_funcs = self._get_reward_functions()
+        reward_sums: Dict[str, float] = {fn.__name__: 0.0 for fn in reward_funcs}
+        count = 0
+
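+        # Score each example: skip entries without a prompt, generate a completion, then apply every reward function.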
+        for example in eval_dataset:
+            if "prompt" not in example:
+                continue
+            messages = example["prompt"]
+            answer = example.get("answer", "")
+
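+            # Render the chat messages into a single prompt string via the tokenizer's chat template.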
+            prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            inputs = tokenizer(prompt_text, return_tensors="pt")
+            input_ids = inputs["input_ids"]
+            attention_mask = inputs.get("attention_mask")
+            if non_default_device_is_available(self._config.DEVICE):
+                input_ids = input_ids.to(self._config.DEVICE)
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(self._config.DEVICE)
+
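+            # do_sample=False makes decoding greedy and deterministic, so the reward scores are reproducible.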
+            with torch.no_grad():
+                generated = model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=max_new_tokens,
+                    do_sample=False,
+                    eos_token_id=tokenizer.eos_token_id,
+                    pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+                )
+
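+            # Drop the prompt tokens so only the newly generated completion is decoded and scored.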
+            completion_text = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)
+            for fn in reward_funcs:
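+                # Build kwargs from the signature so each reward function only receives the arguments it declares.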
+                sig = inspect.signature(fn)
+                kwargs: Dict[str, Any] = {}
+                if "prompts" in sig.parameters:
+                    kwargs["prompts"] = [messages]
+                if "completions" in sig.parameters:
+                    kwargs["completions"] = [({"content": completion_text},)]
+                if "answer" in sig.parameters:
+                    kwargs["answer"] = [answer]
+
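+                # A failing reward function contributes 0.0 instead of aborting the whole evaluation.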
+                try:
+                    rewards = fn(**kwargs)  # type: ignore
+                    value = float(rewards[0]) if isinstance(rewards, (list, tuple)) and rewards else float(rewards)
+                except Exception:
+                    value = 0.0
+
+                reward_sums[fn.__name__] += value
+            count += 1
+        if count == 0:
+            return {}
+
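+        # Average each reward over the scored samples, plus an overall mean and the sample count for context.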
+        reward_avgs = {f"reward_{name}": total / count for name, total in reward_sums.items()}
+        reward_overall_mean = sum(reward_avgs.values()) / len(reward_avgs) if reward_avgs else 0.0
+        reward_avgs["reward_overall_mean"] = reward_overall_mean
+        reward_avgs["reward_samples"] = float(count)
+        return reward_avgs
+

 @final
 class MLflowLoggingCallback(TrainerCallback):