From b894c775bbe9865aa9a14db15e0e41df23fe2990 Mon Sep 17 00:00:00 2001
From: yuchang
Date: Tue, 27 May 2025 15:43:51 +0800
Subject: [PATCH 1/2] add response metrics

---
 trinity/common/workflows/workflow.py |  3 ++
 trinity/utils/eval_utils.py          | 55 ++++++++++++++++++++++++++++
 2 files changed, 58 insertions(+)

diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 9786bd6b77..8c6a165b2c 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -14,6 +14,7 @@
 from trinity.common.experience import Experience
 from trinity.common.models.model import ModelWrapper
 from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn
+from trinity.utils.eval_utils import compute_response_metrics
 from trinity.utils.log import get_logger
 from trinity.utils.registry import Registry
 
@@ -213,6 +214,8 @@ def run(self) -> List[Experience]:
             response.metrics.update(reward)
             reward = sum(reward.values())
             response.reward = reward
+        if self.is_eval:
+            responses = compute_response_metrics(responses)
         return responses
 
 
diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py
index e3aa216eda..01c1097482 100644
--- a/trinity/utils/eval_utils.py
+++ b/trinity/utils/eval_utils.py
@@ -76,3 +76,58 @@ def evaluate_equation(equation_str):
         return result
     except Exception as e:  # noqa: F841
         return None
+
+
+def compute_response_metrics(responses):
+    if not responses:
+        return responses
+
+    acc_key = next((k for k, v in responses[0].metrics.items() if "accuracy" in k), None)
+    if acc_key is None:
+        raise ValueError("No accuracy metric found in responses.")
+
+    total_response_length = 0
+    total_correct_length = 0
+    total_incorrect_length = 0
+    pass_k_count = 0
+    num_responses = len(responses)
+
+    for response in responses:
+        tokens_length = len(response.tokens) - response.prompt_length
+        is_correct = response.metrics.get(acc_key) == 1.0
+
+        total_response_length += tokens_length
+        if is_correct:
+            pass_k_count += 1
+            total_correct_length += tokens_length
+        else:
+            total_incorrect_length += tokens_length
+
+    avg_response_length = total_response_length / num_responses
+    avg_pass_k = pass_k_count / num_responses
+    avg_correct_length = total_correct_length / pass_k_count if pass_k_count > 0 else None
+    avg_incorrect_length = (
+        total_incorrect_length / (num_responses - pass_k_count)
+        if num_responses > pass_k_count
+        else None
+    )
+
+    metrics = {
+        "response_length": avg_response_length,
+        "pass@k": avg_pass_k,
+        **(
+            {"response_length_correct": avg_correct_length}
+            if avg_correct_length is not None
+            else {}
+        ),
+        **(
+            {"response_length_wrong": avg_incorrect_length}
+            if avg_incorrect_length is not None
+            else {}
+        ),
+    }
+
+    for response in responses:
+        response.metrics.update(metrics)
+
+    return responses

From 65ee18f90c2824536ea998c486e7e0cc58f8f27e Mon Sep 17 00:00:00 2001
From: yuchang
Date: Tue, 27 May 2025 16:13:46 +0800
Subject: [PATCH 2/2] add response metrics

---
 trinity/utils/eval_utils.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py
index 01c1097482..5d6a29a406 100644
--- a/trinity/utils/eval_utils.py
+++ b/trinity/utils/eval_utils.py
@@ -89,8 +89,8 @@ def compute_response_metrics(responses):
     total_response_length = 0
     total_correct_length = 0
     total_incorrect_length = 0
-    pass_k_count = 0
-    num_responses = len(responses)
+    n_correct = 0
+    n_responses = len(responses)
 
     for response in responses:
         tokens_length = len(response.tokens) - response.prompt_length
@@ -98,23 +98,19 @@ def compute_response_metrics(responses):
 
         total_response_length += tokens_length
         if is_correct:
-            pass_k_count += 1
+            n_correct += 1
             total_correct_length += tokens_length
         else:
             total_incorrect_length += tokens_length
 
-    avg_response_length = total_response_length / num_responses
-    avg_pass_k = pass_k_count / num_responses
-    avg_correct_length = total_correct_length / pass_k_count if pass_k_count > 0 else None
+    avg_response_length = total_response_length / n_responses
+    avg_correct_length = total_correct_length / n_correct if n_correct > 0 else None
     avg_incorrect_length = (
-        total_incorrect_length / (num_responses - pass_k_count)
-        if num_responses > pass_k_count
-        else None
+        total_incorrect_length / (n_responses - n_correct) if n_responses > n_correct else None
     )
 
     metrics = {
         "response_length": avg_response_length,
-        "pass@k": avg_pass_k,
         **(
             {"response_length_correct": avg_correct_length}
             if avg_correct_length is not None
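
Note for reviewers: the snippet below is a self-contained sketch, not part of the patches, showing what the final compute_response_metrics does after PATCH 2/2. It averages completion lengths over a group of responses, splits the average by correctness using whichever metrics key contains "accuracy", and writes the group-level aggregates back onto every response. FakeResponse is a hypothetical stand-in for the Experience objects the real workflow passes in.

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeResponse:
    # Hypothetical stand-in for trinity's Experience objects: token ids for
    # prompt + completion, the prompt length, and a metrics dict that already
    # holds an accuracy entry written by the reward function.
    tokens: List[int]
    prompt_length: int
    metrics: Dict[str, float]


def compute_response_metrics(responses):
    # Condensed re-statement of eval_utils.compute_response_metrics as of PATCH 2/2.
    if not responses:
        return responses

    acc_key = next((k for k in responses[0].metrics if "accuracy" in k), None)
    if acc_key is None:
        raise ValueError("No accuracy metric found in responses.")

    n_responses = len(responses)
    n_correct = 0
    total_len = total_correct_len = total_incorrect_len = 0

    for response in responses:
        length = len(response.tokens) - response.prompt_length
        total_len += length
        if response.metrics.get(acc_key) == 1.0:
            n_correct += 1
            total_correct_len += length
        else:
            total_incorrect_len += length

    metrics = {"response_length": total_len / n_responses}
    if n_correct > 0:
        metrics["response_length_correct"] = total_correct_len / n_correct
    if n_responses > n_correct:
        metrics["response_length_wrong"] = total_incorrect_len / (n_responses - n_correct)

    # The same group-level averages are attached to every response.
    for response in responses:
        response.metrics.update(metrics)
    return responses


if __name__ == "__main__":
    group = [
        FakeResponse(tokens=list(range(10)), prompt_length=4, metrics={"accuracy": 1.0}),
        FakeResponse(tokens=list(range(12)), prompt_length=4, metrics={"accuracy": 0.0}),
    ]
    compute_response_metrics(group)
    # -> {'accuracy': 1.0, 'response_length': 7.0,
    #     'response_length_correct': 6.0, 'response_length_wrong': 8.0}
    print(group[0].metrics)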