@@ -76,3 +76,58 @@ def evaluate_equation(equation_str):
         return result
     except Exception as e:  # noqa: F841
         return None
+
+
+def compute_response_metrics(responses):
+    """Aggregate per-response length and pass@k metrics across a group of responses."""
+    if not responses:
+        return responses
+
+    # Find the first metric key that reports accuracy; different tasks may
+    # name it differently, so match on the substring.
+    acc_key = next((k for k in responses[0].metrics if "accuracy" in k), None)
+    if acc_key is None:
+        raise ValueError("No accuracy metric found in responses.")
+
+    total_response_length = 0
+    total_correct_length = 0
+    total_incorrect_length = 0
+    pass_k_count = 0
+    num_responses = len(responses)
+
+    for response in responses:
+        # Response length excludes the prompt tokens.
+        tokens_length = len(response.tokens) - response.prompt_length
+        is_correct = response.metrics.get(acc_key) == 1.0
+
+        total_response_length += tokens_length
+        if is_correct:
+            pass_k_count += 1
+            total_correct_length += tokens_length
+        else:
+            total_incorrect_length += tokens_length
+
+    avg_response_length = total_response_length / num_responses
+    avg_pass_k = pass_k_count / num_responses
+    # Per-bucket length averages are only defined when the bucket is non-empty.
+    avg_correct_length = total_correct_length / pass_k_count if pass_k_count > 0 else None
+    avg_incorrect_length = (
+        total_incorrect_length / (num_responses - pass_k_count)
+        if num_responses > pass_k_count
+        else None
+    )
+
+    metrics = {
+        "response_length": avg_response_length,
+        "pass@k": avg_pass_k,
+        **(
+            {"response_length_correct": avg_correct_length}
+            if avg_correct_length is not None
+            else {}
+        ),
+        **(
+            {"response_length_wrong": avg_incorrect_length}
+            if avg_incorrect_length is not None
+            else {}
+        ),
+    }
+
+    # Broadcast the aggregated metrics onto every response in the group.
+    for response in responses:
+        response.metrics.update(metrics)
+
+    return responses
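
For reference, a minimal sketch of how this helper might be exercised, assuming response objects that expose `tokens`, `prompt_length`, and a `metrics` dict as the function requires. The `Response` dataclass below is a hypothetical stand-in, not the repository's actual type:

```python
from dataclasses import dataclass, field

@dataclass
class Response:  # hypothetical stand-in for the repo's response type
    tokens: list
    prompt_length: int
    metrics: dict = field(default_factory=dict)

# Two sampled responses for one prompt: one correct, one wrong.
responses = [
    Response(tokens=list(range(12)), prompt_length=4, metrics={"accuracy": 1.0}),
    Response(tokens=list(range(20)), prompt_length=4, metrics={"accuracy": 0.0}),
]

responses = compute_response_metrics(responses)
print(responses[0].metrics["pass@k"])           # 1 correct of 2 -> 0.5
print(responses[0].metrics["response_length"])  # (8 + 16) / 2 = 12.0
```

Note that the aggregates are written back onto every response in the group, so downstream logging can read the group-level metrics from any single sample.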