@@ -256,7 +256,7 @@ def analyze_logits_probs(logprobs_data: List[Dict]) -> Dict:
256256 "token_count" : len (token_entropies )
257257 }
258258
259- def get_llm_response (problem : str , model : str , analyze_logits : bool = False ) -> Union [str , List [Dict ]]:
259+ def get_llm_response (problem : str , model : str , analyze_logits : bool = False , extra_body : dict = None ) -> Union [str , List [Dict ]]:
260260 """
261261 Get response from the LLM for a given problem.
262262 If multiple choices are returned, formats them as attempt dictionaries.
@@ -276,18 +276,16 @@ def get_llm_response(problem: str, model: str, analyze_logits: bool = False) ->
276276 kwargs ["logprobs" ] = True
277277 kwargs ["top_logprobs" ] = 3
278278
279+ # Add extra_body if provided
280+ if extra_body :
281+ kwargs ["extra_body" ] = extra_body
282+
279283 response = client .with_options (timeout = 1000.0 ).chat .completions .create (
280284 model = model ,
281285 messages = [
282286 {"role" : "user" , "content" : SYSTEM_PROMPT + problem }
283287 ],
284288 max_tokens = 8192 ,
285- # extra_body={
286- # "decoding": "thinkdeeper",
287- # "min_thinking_tokens" : 0,
288- # "max_thinking_tokens" : 8000,
289- # "max_thoughts": 100,
290- # },
291289 ** kwargs
292290 )
293291
@@ -333,7 +331,7 @@ def get_llm_response(problem: str, model: str, analyze_logits: bool = False) ->
333331 logger .error (f"Error getting LLM response: { e } " )
334332 return ""
335333
336- def make_n_attempts (problem : str , model : str , n : int , analyze_thoughts : bool = False , analyze_logits : bool = False ) -> List [Dict ]:
334+ def make_n_attempts (problem : str , model : str , n : int , analyze_thoughts : bool = False , analyze_logits : bool = False , extra_body : dict = None ) -> List [Dict ]:
337335 """
338336 Make n attempts to solve a problem and return all responses and predictions.
339337
@@ -351,7 +349,7 @@ def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = F
351349 remaining_attempts = n
352350
353351 while remaining_attempts > 0 :
354- response = get_llm_response (problem , model , analyze_logits )
352+ response = get_llm_response (problem , model , analyze_logits , extra_body )
355353
356354 # If response is already formatted as attempts
357355 if isinstance (response , list ):
@@ -774,7 +772,7 @@ def save_raw_response(filename: str, problem_id: int, response_data: Dict):
774772
775773 return response_id
776774
777- def main (model : str , n_attempts : int , analyze_thoughts : bool = False , analyze_logits : bool = False ):
775+ def main (model : str , n_attempts : int , analyze_thoughts : bool = False , analyze_logits : bool = False , test_time_compute : bool = False , approach_name : str = None , extra_body : dict = None ):
778776 """Main evaluation function that handles gaps in processed indexes."""
779777 os .makedirs ("results" , exist_ok = True )
780778
@@ -784,6 +782,8 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo
784782 suffix_parts .append ("thought_analysis" )
785783 if analyze_logits :
786784 suffix_parts .append ("logit_analysis" )
785+ if approach_name :
786+ suffix_parts .append (approach_name )
787787
788788 suffix = "_" + "_" .join (suffix_parts ) if suffix_parts else ""
789789 results_file = f"results/evaluation_results_{ model .replace ('/' , '_' )} _pass_at_{ n_attempts } { suffix } .json"
@@ -804,7 +804,7 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo
804804 correct_answer = int (item ['answer' ])
805805
806806 # Make n attempts for each problem
807- attempts = make_n_attempts (problem_text , model , n_attempts , analyze_thoughts , analyze_logits )
807+ attempts = make_n_attempts (problem_text , model , n_attempts , analyze_thoughts , analyze_logits , extra_body )
808808 is_correct , first_correct = evaluate_pass_at_n (attempts , correct_answer )
809809
810810 result = {
@@ -826,6 +826,51 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo
826826 parser .add_argument ("--n" , type = int , default = 1 , help = "Number of attempts per problem (for pass@n evaluation)" )
827827 parser .add_argument ("--analyze-thoughts" , action = "store_true" , help = "Analyze thinking patterns in responses" )
828828 parser .add_argument ("--analyze-logits" , action = "store_true" , help = "Analyze token probability distributions" )
829+ parser .add_argument ("--test-time-compute" , action = "store_true" , help = "Evaluate test-time compute scaling approaches" )
829830 args = parser .parse_args ()
830831
831- main (args .model , args .n , args .analyze_thoughts , args .analyze_logits )
832+ if args .test_time_compute :
833+ # Define test-time compute approaches with same config as eval_optillmbench.py
834+ TEST_TIME_COMPUTE_APPROACHES = [
835+ # Baseline
836+ ("none" , "Baseline without any optimization" , {}),
837+
838+ # Sequential test-time compute using thinkdeeper with controlled thinking budgets
839+ ("thinkdeeper_2k" , "ThinkDeeper with 2K thinking tokens" , {
840+ "decoding" : "thinkdeeper" ,
841+ "min_thinking_tokens" : 2048 ,
842+ "max_thinking_tokens" : 2560 , # min + 512 for flexibility
843+ "max_tokens" : 3072 # Total budget: max_thinking_tokens + 512
844+ }),
845+ ("thinkdeeper_4k" , "ThinkDeeper with 4K thinking tokens" , {
846+ "decoding" : "thinkdeeper" ,
847+ "min_thinking_tokens" : 4096 ,
848+ "max_thinking_tokens" : 4608 , # min + 512 for flexibility
849+ "max_tokens" : 5120 # Total budget: max_thinking_tokens + 512
850+ }),
851+ ("thinkdeeper_8k" , "ThinkDeeper with 8K thinking tokens" , {
852+ "decoding" : "thinkdeeper" ,
853+ "min_thinking_tokens" : 8192 ,
854+ "max_thinking_tokens" : 8704 , # min + 512 for flexibility
855+ "max_tokens" : 9216 # Total budget: max_thinking_tokens + 512
856+ }),
857+
858+ # Parallel test-time compute using majority voting with different k values
859+ ("majority_voting_3" , "Majority Voting with k=3" , {"k" : 3 }),
860+ ("majority_voting_6" , "Majority Voting with k=6" , {"k" : 6 }),
861+ ("majority_voting_9" , "Majority Voting with k=9" , {"k" : 9 }),
862+ ]
863+
864+ # Run evaluation for each approach
865+ for approach_slug , approach_name , extra_body in TEST_TIME_COMPUTE_APPROACHES :
866+ print (f"\n { '=' * 80 } " )
867+ print (f"Evaluating: { approach_name } " )
868+ print (f"Model: { args .model } " )
869+ print (f"Approach: { approach_slug } " )
870+ print (f"Extra body: { extra_body } " )
871+ print (f"{ '=' * 80 } \n " )
872+
873+ main (args .model , args .n , args .analyze_thoughts , args .analyze_logits ,
874+ test_time_compute = True , approach_name = approach_slug , extra_body = extra_body )
875+ else :
876+ main (args .model , args .n , args .analyze_thoughts , args .analyze_logits )
0 commit comments