Commit f7af67a

feat: add the option to include rewards metrics
1 parent 1e36b85 commit f7af67a
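
Opting in (illustrative note, not part of the commit): the reward pass is gated on a single boolean in training_params, read in the diff below as training_params.get("include_rewards_metrics", False). A minimal sketch of what a caller might pass; everything except that key is a placeholder:

    # Hypothetical caller-side snippet; only the "include_rewards_metrics" key comes
    # from this commit, the rest of the dict contents are placeholders.
    training_params = {
        "include_rewards_metrics": True,  # enable reward-based metrics after trainer.evaluate()
        # ... other training parameters unchanged ...
    }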

File tree

1 file changed (+88 -3)


app/trainers/huggingface_llm_trainer.py

Lines changed: 88 additions & 3 deletions
@@ -7,6 +7,7 @@
 import re
 import threading
 import json
+import inspect
 import pandas as pd
 from typing import final, Dict, TextIO, Optional, Any, List, Tuple, TYPE_CHECKING, Callable
 from transformers import __version__ as transformers_version
@@ -482,6 +483,7 @@ def run(
         else:
             try:
                 logger.info("Evaluating the running model...")
+                include_rewards_metrics = training_params.get("include_rewards_metrics", False)
                 model, tokenizer = self._model_service.model, self._model_service.tokenizer
                 if non_default_device_is_available(self._config.DEVICE):
                     model.to(self._config.DEVICE)
@@ -528,8 +530,23 @@ def run(
                 )

                 eval_metrics = trainer.evaluate()
+                if "perplexity" not in eval_metrics and "eval_loss" in eval_metrics:
+                    eval_metrics.update({"perplexity": math.exp(eval_metrics["eval_loss"])})
                 logger.info(f"Evaluation metrics: {eval_metrics}")
                 self._tracker_client.send_hf_metrics_logs(eval_metrics, 0)
+                if include_rewards_metrics:
+                    try:
+                        reward_metrics = self._evaluate_with_rewards(
+                            model=model,
+                            tokenizer=tokenizer,
+                            eval_dataset=eval_dataset,
+                            max_new_tokens=training_args.max_completion_length,
+                        )
+                        if reward_metrics:
+                            logger.info(f"Reward metrics: {reward_metrics}")
+                            self._tracker_client.send_hf_metrics_logs(reward_metrics, 0)
+                    except Exception as e:
+                        logger.warning(f"Failed to compute reward-based metrics: {e}")
                 self._tracker_client.end_with_success()
                 logger.info("Model evaluation finished")
             except torch.OutOfMemoryError as e:
@@ -577,8 +594,8 @@ def correctness_reward_func(
             answer: List,
             **kwargs: Dict[str, Any]
         ) -> List[float]:
-            responses = [completion[0]['content'] for completion in completions]
-            q = prompts[0][-1]['content']
+            responses = [completion[0]["content"] for completion in completions]
+            q = prompts[0][-1]["content"]
             extracted_responses = [extract_xml_answer(r) for r in responses]
             logger.debug(
                 "%s\nQuestion:\n%s\nAnswer:\n%s\nResponse:\n%s\nExtracted:\n%s",
@@ -591,7 +608,7 @@ def correctness_reward_func(
             return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

         def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
-            responses = [completion[0]['content'] for completion in completions]
+            responses = [completion[0]["content"] for completion in completions]
             extracted_responses = [extract_xml_answer(r) for r in responses]
             return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

@@ -635,6 +652,74 @@ def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> l
             correctness_reward_func,
         ]

+    def _evaluate_with_rewards(
+        self,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizerBase,
+        eval_dataset: datasets.Dataset,
+        max_new_tokens: int,
+    ) -> Dict[str, float]:
+        model.eval()
+        if non_default_device_is_available(self._config.DEVICE):
+            model.to(self._config.DEVICE)
+
+        reward_funcs = self._get_reward_functions()
+        reward_sums: Dict[str, float] = {fn.__name__: 0.0 for fn in reward_funcs}
+        count = 0
+
+        for example in eval_dataset:
+            if "prompt" not in example:
+                continue
+            messages = example["prompt"]
+            answer = example.get("answer", "")
+
+            prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            inputs = tokenizer(prompt_text, return_tensors="pt")
+            input_ids = inputs["input_ids"]
+            attention_mask = inputs.get("attention_mask")
+            if non_default_device_is_available(self._config.DEVICE):
+                input_ids = input_ids.to(self._config.DEVICE)
+                attention_mask = attention_mask.to(self._config.DEVICE)
+
+            with torch.no_grad():
+                generated = model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=max_new_tokens,
+                    do_sample=False,
+                    temperature=0.0,
+                    eos_token_id=getattr(tokenizer, "eos_token_id", None),
+                    pad_token_id=getattr(tokenizer, "pad_token_id", 0),
+                )
+
+            completion_text = tokenizer.decode(generated[0][input_ids.shape[1]:], skip_special_tokens=True)
+            for fn in reward_funcs:
+                sig = inspect.signature(fn)
+                kwargs: Dict[str, Any] = {}
+                if "prompts" in sig.parameters:
+                    kwargs["prompts"] = [messages]
+                if "completions" in sig.parameters:
+                    kwargs["completions"] = [({"content": completion_text},)]
+                if "answer" in sig.parameters:
+                    kwargs["answer"] = [answer]
+
+                try:
+                    rewards = fn(**kwargs)  # type: ignore
+                    value = float(rewards[0]) if isinstance(rewards, (list, tuple)) and rewards else float(rewards)
+                except Exception:
+                    value = 0.0
+
+                reward_sums[fn.__name__] += value
+            count += 1
+        if count == 0:
+            return {}
+
+        reward_avgs = {f"reward_{name}": total / count for name, total in reward_sums.items()}
+        reward_overall_mean = sum(reward_avgs.values()) / len(reward_avgs) if reward_avgs else 0.0
+        reward_avgs["reward_overall_mean"] = reward_overall_mean
+        reward_avgs["reward_samples"] = float(count)
+        return reward_avgs
+

 @final
 class MLflowLoggingCallback(TrainerCallback):
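
Shape of the reported metrics (illustrative note, not part of the commit): _evaluate_with_rewards averages each reward function over the evaluated examples and adds an overall mean and a sample count. The self-contained sketch below mirrors only that aggregation and the reward_* key naming from the diff above; the helper name summarise_rewards and the example reward values are invented:

    # Minimal sketch (assumption: not part of the commit) of the final aggregation
    # in _evaluate_with_rewards; input values are made up for illustration.
    from typing import Dict, List

    def summarise_rewards(per_example: Dict[str, List[float]]) -> Dict[str, float]:
        count = len(next(iter(per_example.values()), []))
        if count == 0:
            return {}
        # One averaged entry per reward function, named reward_<function_name>.
        avgs = {f"reward_{name}": sum(values) / count for name, values in per_example.items()}
        # Mean of the per-function averages, computed before the extra keys are added.
        avgs["reward_overall_mean"] = sum(avgs.values()) / len(avgs)
        avgs["reward_samples"] = float(count)
        return avgs

    print(summarise_rewards({
        "correctness_reward_func": [2.0, 0.0, 2.0],  # illustrative rewards only
        "int_reward_func": [0.5, 0.5, 0.0],
    }))
    # -> {'reward_correctness_reward_func': 1.33..., 'reward_int_reward_func': 0.33...,
    #     'reward_overall_mean': 0.83..., 'reward_samples': 3.0}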
