3 changes: 3 additions & 0 deletions trinity/common/workflows/workflow.py
@@ -14,6 +14,7 @@
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn
from trinity.utils.eval_utils import compute_response_metrics
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

@@ -213,6 +214,8 @@ def run(self) -> List[Experience]:
response.metrics.update(reward)
reward = sum(reward.values())
response.reward = reward
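        # During evaluation, append batch-level aggregates (pass@k, response
        # lengths) to every response's metrics before returning.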
if self.is_eval:
responses = compute_response_metrics(responses)
return responses


55 changes: 55 additions & 0 deletions trinity/utils/eval_utils.py
@@ -76,3 +76,58 @@ def evaluate_equation(equation_str):
return result
except Exception as e: # noqa: F841
return None


def compute_response_metrics(responses):
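    """Attach batch-level eval aggregates to every response's metrics.

    Computes the mean generated-token length, pass@k (fraction of responses
    whose accuracy metric equals 1.0), and the mean lengths of correct and
    incorrect responses, then merges these values into each response's
    metrics dict.
    """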
if not responses:
return responses

acc_key = next((k for k, v in responses[0].metrics.items() if "accuracy" in k), None)
if acc_key is None:
raise ValueError("No accuracy metric found in responses.")

total_response_length = 0
total_correct_length = 0
total_incorrect_length = 0
pass_k_count = 0
num_responses = len(responses)

for response in responses:
tokens_length = len(response.tokens) - response.prompt_length
is_correct = response.metrics.get(acc_key) == 1.0

total_response_length += tokens_length
if is_correct:
pass_k_count += 1
total_correct_length += tokens_length
else:
total_incorrect_length += tokens_length

avg_response_length = total_response_length / num_responses
avg_pass_k = pass_k_count / num_responses
avg_correct_length = total_correct_length / pass_k_count if pass_k_count > 0 else None
avg_incorrect_length = (
total_incorrect_length / (num_responses - pass_k_count)
if num_responses > pass_k_count
else None
)

metrics = {
"response_length": avg_response_length,
"pass@k": avg_pass_k,
**(
{"response_length_correct": avg_correct_length}
if avg_correct_length is not None
else {}
),
**(
{"response_length_wrong": avg_incorrect_length}
if avg_incorrect_length is not None
else {}
),
}

for response in responses:
response.metrics.update(metrics)

return responses
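
For reference, a minimal, hypothetical sketch of how the new helper behaves on a toy batch; the make_response stand-in and SimpleNamespace are illustrative only (the real objects are Experience instances from trinity.common.experience) and are not part of this change.

# Hedged usage sketch, not part of the PR: the stand-in mimics only the
# fields the helper touches (tokens, prompt_length, metrics).
from types import SimpleNamespace

from trinity.utils.eval_utils import compute_response_metrics


def make_response(num_tokens, prompt_length, correct):
    return SimpleNamespace(
        tokens=list(range(num_tokens)),
        prompt_length=prompt_length,
        metrics={"accuracy": 1.0 if correct else 0.0},
    )


responses = compute_response_metrics(
    [
        make_response(30, 10, correct=True),   # 20 generated tokens, correct
        make_response(50, 10, correct=False),  # 40 generated tokens, wrong
    ]
)
# Each response now also carries the batch-level aggregates:
#   response_length = (20 + 40) / 2 = 30.0
#   pass@k = 1 / 2 = 0.5
#   response_length_correct = 20.0, response_length_wrong = 40.0
print(responses[0].metrics)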