
Commit b894c77

add response metrics
1 parent 318da40 commit b894c77


2 files changed: +58 additions, 0 deletions


trinity/common/workflows/workflow.py

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,7 @@
 from trinity.common.experience import Experience
 from trinity.common.models.model import ModelWrapper
 from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn
+from trinity.utils.eval_utils import compute_response_metrics
 from trinity.utils.log import get_logger
 from trinity.utils.registry import Registry

@@ -213,6 +214,8 @@ def run(self) -> List[Experience]:
             response.metrics.update(reward)
             reward = sum(reward.values())
             response.reward = reward
+        if self.is_eval:
+            responses = compute_response_metrics(responses)
         return responses
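The gating only touches evaluation runs: batch-level response metrics are attached when self.is_eval is set, and training rollouts pass through unchanged. As a rough sketch of what the downstream code expects, each object in responses needs tokens, prompt_length, metrics, and reward fields (the fields this commit reads or writes). The stand-in class below is only for illustration and is not the project's real response/experience type.

# Illustrative stand-in only (assumption): the real objects come from the model
# wrapper / Experience machinery in trinity, not from this class. It just mirrors
# the fields this commit reads (tokens, prompt_length, metrics) and writes
# (metrics, reward).
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class FakeResponse:
    tokens: List[int]            # prompt tokens followed by generated tokens
    prompt_length: int           # number of prompt tokens at the front of `tokens`
    metrics: Dict[str, float] = field(default_factory=dict)
    reward: Optional[float] = None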

trinity/utils/eval_utils.py

Lines changed: 55 additions & 0 deletions
@@ -76,3 +76,58 @@ def evaluate_equation(equation_str):
         return result
     except Exception as e:  # noqa: F841
         return None
+
+
+def compute_response_metrics(responses):
+    if not responses:
+        return responses
+
+    acc_key = next((k for k, v in responses[0].metrics.items() if "accuracy" in k), None)
+    if acc_key is None:
+        raise ValueError("No accuracy metric found in responses.")
+
+    total_response_length = 0
+    total_correct_length = 0
+    total_incorrect_length = 0
+    pass_k_count = 0
+    num_responses = len(responses)
+
+    for response in responses:
+        tokens_length = len(response.tokens) - response.prompt_length
+        is_correct = response.metrics.get(acc_key) == 1.0
+
+        total_response_length += tokens_length
+        if is_correct:
+            pass_k_count += 1
+            total_correct_length += tokens_length
+        else:
+            total_incorrect_length += tokens_length
+
+    avg_response_length = total_response_length / num_responses
+    avg_pass_k = pass_k_count / num_responses
+    avg_correct_length = total_correct_length / pass_k_count if pass_k_count > 0 else None
+    avg_incorrect_length = (
+        total_incorrect_length / (num_responses - pass_k_count)
+        if num_responses > pass_k_count
+        else None
+    )
+
+    metrics = {
+        "response_length": avg_response_length,
+        "pass@k": avg_pass_k,
+        **(
+            {"response_length_correct": avg_correct_length}
+            if avg_correct_length is not None
+            else {}
+        ),
+        **(
+            {"response_length_wrong": avg_incorrect_length}
+            if avg_incorrect_length is not None
+            else {}
+        ),
+    }
+
+    for response in responses:
+        response.metrics.update(metrics)
+
+    return responses
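A hypothetical usage sketch, reusing the FakeResponse stand-in from above: only the import path and the metric key names come from the diff, while the token counts and accuracy values are made up for the example. With one correct response of 30 generated tokens and one wrong response of 50 generated tokens, every response ends up carrying the same batch-level aggregates.

from trinity.utils.eval_utils import compute_response_metrics

# Two fabricated responses; lengths and accuracy values are illustrative only.
responses = [
    FakeResponse(tokens=list(range(40)), prompt_length=10, metrics={"accuracy": 1.0}),  # 30 generated, correct
    FakeResponse(tokens=list(range(60)), prompt_length=10, metrics={"accuracy": 0.0}),  # 50 generated, wrong
]

responses = compute_response_metrics(responses)

# Each response's metrics now also contain the batch-level aggregates:
#   response_length         = (30 + 50) / 2 = 40.0
#   pass@k                  = 1 / 2 = 0.5
#   response_length_correct = 30 / 1 = 30.0
#   response_length_wrong   = 50 / 1 = 50.0
print(responses[0].metrics)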
