diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 9786bd6b77..8c6a165b2c 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -14,6 +14,7 @@
 from trinity.common.experience import Experience
 from trinity.common.models.model import ModelWrapper
 from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn
+from trinity.utils.eval_utils import compute_response_metrics
 from trinity.utils.log import get_logger
 from trinity.utils.registry import Registry
 
@@ -213,6 +214,8 @@ def run(self) -> List[Experience]:
             response.metrics.update(reward)
             reward = sum(reward.values())
             response.reward = reward
+        if self.is_eval:
+            responses = compute_response_metrics(responses)
         return responses
 
 
diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py
index e3aa216eda..5d6a29a406 100644
--- a/trinity/utils/eval_utils.py
+++ b/trinity/utils/eval_utils.py
@@ -76,3 +76,54 @@ def evaluate_equation(equation_str):
         return result
     except Exception as e:  # noqa: F841
         return None
+
+
+def compute_response_metrics(responses):
+    if not responses:
+        return responses
+
+    acc_key = next((k for k, v in responses[0].metrics.items() if "accuracy" in k), None)
+    if acc_key is None:
+        raise ValueError("No accuracy metric found in responses.")
+
+    total_response_length = 0
+    total_correct_length = 0
+    total_incorrect_length = 0
+    n_correct = 0
+    n_responses = len(responses)
+
+    for response in responses:
+        tokens_length = len(response.tokens) - response.prompt_length
+        is_correct = response.metrics.get(acc_key) == 1.0
+
+        total_response_length += tokens_length
+        if is_correct:
+            n_correct += 1
+            total_correct_length += tokens_length
+        else:
+            total_incorrect_length += tokens_length
+
+    avg_response_length = total_response_length / n_responses
+    avg_correct_length = total_correct_length / n_correct if n_correct > 0 else None
+    avg_incorrect_length = (
+        total_incorrect_length / (n_responses - n_correct) if n_responses > n_correct else None
+    )
+
+    metrics = {
+        "response_length": avg_response_length,
+        **(
+            {"response_length_correct": avg_correct_length}
+            if avg_correct_length is not None
+            else {}
+        ),
+        **(
+            {"response_length_wrong": avg_incorrect_length}
+            if avg_incorrect_length is not None
+            else {}
+        ),
+    }
+
+    for response in responses:
+        response.metrics.update(metrics)
+
+    return responses
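
For reference, a minimal usage sketch of the new compute_response_metrics helper (not part of the patch). It assumes only that each response object exposes the tokens, prompt_length, and metrics attributes that the helper reads, as the Experience objects in the eval path of run() do; the FakeResponse stand-in and its values below are hypothetical.

# Illustrative sketch only: exercises compute_response_metrics with a
# hypothetical stand-in for trinity.common.experience.Experience.
from dataclasses import dataclass
from typing import Dict, List

from trinity.utils.eval_utils import compute_response_metrics


@dataclass
class FakeResponse:  # hypothetical stand-in; mirrors the attributes the helper reads
    tokens: List[int]      # prompt tokens + generated tokens
    prompt_length: int     # number of prompt tokens at the front of `tokens`
    metrics: Dict[str, float]


responses = [
    FakeResponse(tokens=list(range(30)), prompt_length=10, metrics={"accuracy": 1.0}),
    FakeResponse(tokens=list(range(50)), prompt_length=10, metrics={"accuracy": 0.0}),
]
responses = compute_response_metrics(responses)
# Every response now also carries the batch-level averages:
#   response_length          -> (20 + 40) / 2 = 30.0
#   response_length_correct  -> 20.0
#   response_length_wrong    -> 40.0
print(responses[0].metrics)

Since the same averages are written into every response's metrics dict, downstream metric aggregation over an eval batch sees a consistent per-batch value regardless of which response it reads.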