Skip to content

Commit 65ee18f

Browse files
committed
add response metrics
1 parent b894c77 commit 65ee18f

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

trinity/utils/eval_utils.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -89,32 +89,28 @@ def compute_response_metrics(responses):
8989
total_response_length = 0
9090
total_correct_length = 0
9191
total_incorrect_length = 0
92-
pass_k_count = 0
93-
num_responses = len(responses)
92+
n_correct = 0
93+
n_responses = len(responses)
9494

9595
for response in responses:
9696
tokens_length = len(response.tokens) - response.prompt_length
9797
is_correct = response.metrics.get(acc_key) == 1.0
9898

9999
total_response_length += tokens_length
100100
if is_correct:
101-
pass_k_count += 1
101+
n_correct += 1
102102
total_correct_length += tokens_length
103103
else:
104104
total_incorrect_length += tokens_length
105105

106-
avg_response_length = total_response_length / num_responses
107-
avg_pass_k = pass_k_count / num_responses
108-
avg_correct_length = total_correct_length / pass_k_count if pass_k_count > 0 else None
106+
avg_response_length = total_response_length / n_responses
107+
avg_correct_length = total_correct_length / n_correct if n_correct > 0 else None
109108
avg_incorrect_length = (
110-
total_incorrect_length / (num_responses - pass_k_count)
111-
if num_responses > pass_k_count
112-
else None
109+
total_incorrect_length / (n_responses - n_correct) if n_responses > n_correct else None
113110
)
114111

115112
metrics = {
116113
"response_length": avg_response_length,
117-
"pass@k": avg_pass_k,
118114
**(
119115
{"response_length_correct": avg_correct_length}
120116
if avg_correct_length is not None

0 commit comments

Comments (0)