44import logging
55import re
66import time
7- from typing import List , Dict , Tuple , Optional
7+ from typing import List , Dict , Tuple , Optional , Union
88from datetime import datetime
99from openai import OpenAI
1010from datasets import load_dataset
@@ -89,9 +89,17 @@ def extract_answer(response: str) -> Optional[int]:
8989
9090 return None
9191
92- def get_llm_response (problem : str , model : str ) -> str :
92+ def get_llm_response (problem : str , model : str ) -> Union [ str , List [ Dict ]] :
9393 """
9494 Get response from the LLM for a given problem.
95+ If multiple choices are returned, formats them as attempt dictionaries.
96+
97+ Args:
98+ problem (str): The problem text
99+ model (str): The model identifier
100+
101+ Returns:
102+ Union[str, List[Dict]]: Either a string response or list of attempt dictionaries
95103 """
96104 try :
97105 response = client .with_options (timeout = 1000.0 ).chat .completions .create (
@@ -101,7 +109,23 @@ def get_llm_response(problem: str, model: str) -> str:
101109 ],
102110 max_tokens = 8192 ,
103111 )
112+
113+ # If there's more than one choice, format as attempts
114+ if len (response .choices ) > 1 :
115+ attempts = []
116+ for i , choice in enumerate (response .choices ):
117+ response_text = choice .message .content .strip ()
118+ predicted_answer = extract_answer (response_text )
119+ attempts .append ({
120+ "attempt_number" : i + 1 ,
121+ "response" : response_text ,
122+ "predicted_answer" : predicted_answer
123+ })
124+ return attempts
125+
126+ # If single choice, return as before
104127 return response .choices [0 ].message .content .strip ()
128+
105129 except Exception as e :
106130 logger .error (f"Error getting LLM response: { e } " )
107131 return ""
def make_n_attempts(problem: str, model: str, n: int) -> List[Dict]:
    """
    Make n solution attempts for a problem, collecting each model response
    and its extracted answer.

    get_llm_response may return either a single string (one completion choice)
    or a pre-formatted list of attempt dictionaries (multiple choices); both
    shapes are consumed here until exactly n attempts have been gathered.

    Args:
        problem (str): The problem text passed to the model.
        model (str): The model identifier.
        n (int): Total number of attempts to collect.

    Returns:
        List[Dict]: List of dictionaries containing response and predicted answer for each attempt
    """
    attempts = []
    remaining_attempts = n

    while remaining_attempts > 0:
        response = get_llm_response(problem, model)

        # If response is already formatted as attempts (list of dicts)
        if isinstance(response, list):
            # Take only as many as we still need.
            taken = response[:remaining_attempts]
            for attempt in taken:
                # Renumber so attempt_number stays globally sequential across
                # API calls instead of restarting at 1 inside each batch
                # (the batch was numbered locally by get_llm_response).
                attempt["attempt_number"] = len(attempts) + 1
                attempts.append(attempt)
            # Subtract only what was actually kept; subtracting the full
            # batch length would over-count when the batch was truncated.
            remaining_attempts -= len(taken)
        else:
            # Single string response: extract the answer and record it.
            predicted_answer = extract_answer(response)
            attempts.append({
                "attempt_number": len(attempts) + 1,
                "response": response,
                "predicted_answer": predicted_answer
            })
            remaining_attempts -= 1

    return attempts
131166
132167def evaluate_pass_at_n (attempts : List [Dict ], correct_answer : int ) -> Tuple [bool , Optional [int ]]:
0 commit comments