import argparse
import json
import os
import logging
import re
import time
from typing import List, Dict, Tuple, Optional
from datetime import datetime
from openai import OpenAI
from datasets import load_dataset
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize OpenAI client
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url="http://localhost:8888/v1")

SYSTEM_PROMPT = '''You are solving AIME (American Invitational Mathematics Examination) problems.

[... remainder of SYSTEM_PROMPT elided ...]'''

def extract_answer(response: str) -> Optional[int]:
    """
    Extract the numerical answer from a math solution response.
    Handles various formats of boxed answers and falls back to the last number if needed.

    Args:
        response (str): The complete response text from the model

    Returns:
        Optional[int]: The extracted answer as an integer, or None if no valid answer is found
    """
    if not response:
        return None

    # Clean the response: normalize whitespace
    response = ' '.join(response.split())

    # Regex patterns to try, in order of preference
    patterns = [
        # $n=\boxed{X}$ format
        r'\$n=\\boxed{(\d+)}\$',
        # LaTeX display-style answer: \[\boxed{X}\] or \[\boxed{X}.\]
        r'\\\[\\boxed{(\d+)}\\\]',
        r'\\\[\\boxed{(\d+)}\.\\\]',
        # Inline LaTeX \boxed{X}
        r'\\boxed{(\d+)}',
        # Common variations
        r'\$\\boxed{(\d+)}\$',
        r'boxed{(\d+)}',
        # Less strict patterns
        r'\\boxed\s*{\s*(\d+)\s*}',
        r'\bboxed\s*{\s*(\d+)\s*}',
        # Plain-text answer indicators
        r'final answer is[^\d]*(\d+)',
        r'answer is[^\d]*(\d+)',
        r'answer:[^\d]*(\d+)',
        r'= ?(\d+)$'
    ]

    for pattern in patterns:
        matches = re.finditer(pattern, response, re.IGNORECASE)
        # Keep only the last match for this pattern (in case there are multiple)
        last_match = None
        for match in matches:
            last_match = match
        if last_match:
            try:
                return int(last_match.group(1))
            except (ValueError, IndexError):
                continue

    # Fallback: extract all numbers and take the last one,
    # assuming the answer typically comes last
    numbers = re.findall(r'(\d+)', response)
    if numbers:
        try:
            return int(numbers[-1])
        except ValueError:
            pass

    # If all methods fail, return None
    return None

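# Illustrative behavior of extract_answer (hypothetical inputs, not dataset text):
#   extract_answer(r"Thus $n=\boxed{204}$")  -> 204   (boxed pattern)
#   extract_answer("The answer is 42")       -> 42    (plain-text indicator)
#   extract_answer("no digits at all")       -> None  (every pattern and the fallback fail)
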
def get_llm_response(problem: str, model: str) -> str:
    """
    Get response from the LLM for a given problem.
    """
    try:
        response = client.with_options(timeout=1000.0).chat.completions.create(
            model=model,
            messages=[
                {"role": "user", "content": SYSTEM_PROMPT + problem}
            ],
            max_tokens=8192,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.error(f"Error getting LLM response: {e}")
        return ""

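# Note: with_options(timeout=1000.0) raises the per-request timeout on the
# OpenAI client, which matters when the local server spends a long time
# generating a full solution before returning.
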
def make_n_attempts(problem: str, model: str, n: int) -> List[Dict]:
    """
    Make n attempts to solve a problem and return all responses and predictions.

    Args:
        problem (str): The problem text
        model (str): The model identifier
        n (int): Number of attempts to make

    Returns:
        List[Dict]: List of dictionaries containing the response and predicted answer for each attempt
    """
    attempts = []
    for i in range(n):
        response = get_llm_response(problem, model)
        predicted_answer = extract_answer(response)
        attempts.append({
            "attempt_number": i + 1,
            "response": response,
            "predicted_answer": predicted_answer
        })
    return attempts

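# Note: the chat request above does not pin temperature, so the server's
# default sampling applies; each attempt is an independent sample, which is
# what makes repeated attempts (and hence pass@n) meaningful.
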
def evaluate_pass_at_n(attempts: List[Dict], correct_answer: int) -> Tuple[bool, Optional[int]]:
    """
    Evaluate if any of the n attempts got the correct answer.

    Args:
        attempts (List[Dict]): List of attempt results
        correct_answer (int): The correct answer

    Returns:
        Tuple[bool, Optional[int]]: (whether any attempt was correct, first correct attempt number)
    """
    for attempt in attempts:
        if attempt["predicted_answer"] == correct_answer:
            return True, attempt["attempt_number"]
    return False, None

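# Worked example (hypothetical): with correct_answer=73 and predicted answers
# [12, 73, 73] across three attempts, evaluate_pass_at_n returns (True, 2):
# the problem counts as solved under pass@3, first solved on attempt 2.
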
def load_existing_results(filename: str) -> List[Dict]:
    """Load existing results from file if it exists."""
    ...  # (body elided)

# (additional helpers elided, including save_result() and load_2024_dataset())

def get_last_processed_index(results: List[Dict]) -> int:
    if not results:
        return -1
    return max(int(r.get('index', -1)) for r in results)

def analyze_results(results: List[Dict], n: int):
    """
    Analyze and print summary statistics of the results.

    Args:
        results (List[Dict]): List of evaluation results
        n (int): Number of attempts per problem
    """
    total = len(results)
    correct = sum(1 for r in results if r['is_correct'])
    accuracy = correct / total if total > 0 else 0

    print("\n=== Results Summary ===")
    print(f"Evaluation mode: pass@{n}")
    print(f"Total problems: {total}")
    print(f"Correct answers: {correct}")
    print(f"Accuracy: {accuracy:.2%}")

    # Calculate attempt statistics for problems that were eventually solved
    successful_attempts = [r['first_correct_attempt'] for r in results if r['is_correct']]
    if successful_attempts:
        avg_attempts = sum(successful_attempts) / len(successful_attempts)
        print("\nFor correct solutions:")
        print(f"Average attempts needed: {avg_attempts:.2f}")
        print("Attempt distribution:")
        for i in range(1, n + 1):
            count = sum(1 for x in successful_attempts if x == i)
            print(f"  Attempt {i}: {count} problems")

    print("\n=== Incorrect Problems ===")
    for r in results:
        if not r['is_correct']:
            print(f"Problem {r['index']}:")
            print(f"Expected: {r['correct_answer']}")
            print("Predicted answers across attempts:", [
                attempt['predicted_answer'] for attempt in r['attempts']
            ])
            print("---")

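# Illustrative output shape from analyze_results (numbers are invented):
#   === Results Summary ===
#   Evaluation mode: pass@4
#   Total problems: 30
#   Correct answers: 17
#   Accuracy: 56.67%
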
def main(model: str, n_attempts: int):
    """Main evaluation function."""
    os.makedirs("results", exist_ok=True)

    # Include n_attempts in the filename to keep separate results for different n values
    results_file = f"evaluation_results_{model.replace('/', '_')}_pass_at_{n_attempts}.json"

    dataset = load_2024_dataset()
    existing_results = load_existing_results(results_file)
    last_processed_index = get_last_processed_index(existing_results)

    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx <= last_processed_index:
            continue

        problem_text = item['problem']
        correct_answer = int(item['answer'])

        # Make n attempts for each problem
        attempts = make_n_attempts(problem_text, model, n_attempts)
        is_correct, first_correct = evaluate_pass_at_n(attempts, correct_answer)

        result = {
            "index": idx,
            "problem": problem_text,
            "attempts": attempts,
            "correct_answer": correct_answer,
            "is_correct": is_correct,
            "first_correct_attempt": first_correct
        }
        save_result(results_file, result)

    final_results = load_existing_results(results_file)
    analyze_results(final_results, n_attempts)

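# Design note: save_result() persists each problem as soon as it is evaluated,
# and get_last_processed_index() skips already-processed indices, so an
# interrupted run can be restarted and will resume where it left off.
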
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate LLM performance on AIME 2024 problems")
    parser.add_argument("--model", type=str, required=True, help="OpenAI model to use (e.g., gpt-4, gpt-3.5-turbo)")
    parser.add_argument("--n", type=int, default=1, help="Number of attempts per problem (for pass@n evaluation)")
    args = parser.parse_args()

    main(args.model, args.n)
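
# Example invocation (script and model names are illustrative):
#   python eval_aime_benchmark.py --model gpt-4o-mini --n 4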