3 changes: 3 additions & 0 deletions trinity/common/workflows/workflow.py
@@ -14,6 +14,7 @@
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.reward_fn import MathRewardFn, RewardFn
from trinity.utils.eval_utils import compute_response_metrics
from trinity.utils.log import get_logger
from trinity.utils.registry import Registry

@@ -213,6 +214,8 @@ def run(self) -> List[Experience]:
            response.metrics.update(reward)
            reward = sum(reward.values())
            response.reward = reward
        if self.is_eval:
            responses = compute_response_metrics(responses)
        return responses


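For orientation, a hedged sketch of how the surrounding run() body might read with the collapsed context filled in. Only the `if self.is_eval:` block and the `compute_response_metrics` import come from this diff; every other line, and placeholder names such as format_messages, response_text, and truth, are assumptions for illustration, not confirmed Trinity-RFT API.

# Illustrative sketch -- the lines outside the `if self.is_eval:` block are NOT
# part of this diff, and the placeholder names are assumptions, not real API.
def run(self) -> List[Experience]:
    messages = self.format_messages()      # placeholder: build the chat prompt
    responses = self.model.chat(messages)  # placeholder: ModelWrapper rollout, one response per sample
    for response in responses:
        # placeholder call; the reward_fn (e.g. MathRewardFn) returns a dict of components
        reward = self.reward_fn(response.response_text, self.truth)
        response.metrics.update(reward)    # keep per-component values for logging
        reward = sum(reward.values())
        response.reward = reward           # scalar reward used downstream
    if self.is_eval:
        # Added in this PR: attach aggregate response-length metrics on eval runs only.
        responses = compute_response_metrics(responses)
    return responses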
51 changes: 51 additions & 0 deletions trinity/utils/eval_utils.py
@@ -76,3 +76,54 @@ def evaluate_equation(equation_str):
        return result
    except Exception as e:  # noqa: F841
        return None


def compute_response_metrics(responses):
    """Attach batch-level response-length metrics (overall, correct, wrong) to every response."""
    if not responses:
        return responses

    acc_key = next((k for k, v in responses[0].metrics.items() if "accuracy" in k), None)
    if acc_key is None:
        raise ValueError("No accuracy metric found in responses.")

    total_response_length = 0
    total_correct_length = 0
    total_incorrect_length = 0
    n_correct = 0
    n_responses = len(responses)

    for response in responses:
        tokens_length = len(response.tokens) - response.prompt_length
        is_correct = response.metrics.get(acc_key) == 1.0

        total_response_length += tokens_length
        if is_correct:
            n_correct += 1
            total_correct_length += tokens_length
        else:
            total_incorrect_length += tokens_length

    avg_response_length = total_response_length / n_responses
    avg_correct_length = total_correct_length / n_correct if n_correct > 0 else None
    avg_incorrect_length = (
        total_incorrect_length / (n_responses - n_correct) if n_responses > n_correct else None
    )

    metrics = {
        "response_length": avg_response_length,
        **(
            {"response_length_correct": avg_correct_length}
            if avg_correct_length is not None
            else {}
        ),
        **(
            {"response_length_wrong": avg_incorrect_length}
            if avg_incorrect_length is not None
            else {}
        ),
    }

    for response in responses:
        response.metrics.update(metrics)

    return responses
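
To make the aggregation concrete, here is a minimal, self-contained usage sketch. The FakeResponse stand-in is invented for illustration; in Trinity-RFT the elements are Experience objects, but only the tokens, prompt_length, and metrics attributes read above matter here.

from dataclasses import dataclass, field
from typing import Dict, List

from trinity.utils.eval_utils import compute_response_metrics


@dataclass
class FakeResponse:  # stand-in for Experience; only the attributes used above
    tokens: List[int]
    prompt_length: int
    metrics: Dict[str, float] = field(default_factory=dict)


responses = [
    FakeResponse(tokens=list(range(30)), prompt_length=10, metrics={"accuracy": 1.0}),
    FakeResponse(tokens=list(range(50)), prompt_length=10, metrics={"accuracy": 0.0}),
]
responses = compute_response_metrics(responses)

# Every response now carries the same batch-level aggregates:
#   response_length         = (20 + 40) / 2 = 30.0
#   response_length_correct = 20.0
#   response_length_wrong   = 40.0
print(responses[0].metrics)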