Commit 5809463

support ttc argument

1 parent d2ed063

File tree: 2 files changed (+72 −27 lines)

scripts/eval_aime_benchmark.py

Lines changed: 57 additions & 12 deletions
@@ -256,7 +256,7 @@ def analyze_logits_probs(logprobs_data: List[Dict]) -> Dict:
         "token_count": len(token_entropies)
     }
 
-def get_llm_response(problem: str, model: str, analyze_logits: bool = False) -> Union[str, List[Dict]]:
+def get_llm_response(problem: str, model: str, analyze_logits: bool = False, extra_body: dict = None) -> Union[str, List[Dict]]:
     """
     Get response from the LLM for a given problem.
     If multiple choices are returned, formats them as attempt dictionaries.
@@ -276,18 +276,16 @@ def get_llm_response(problem: str, model: str, analyze_logits: bool = False) ->
         kwargs["logprobs"] = True
         kwargs["top_logprobs"] = 3
 
+    # Add extra_body if provided
+    if extra_body:
+        kwargs["extra_body"] = extra_body
+
     response = client.with_options(timeout=1000.0).chat.completions.create(
         model=model,
         messages=[
             {"role": "user", "content": SYSTEM_PROMPT + problem}
         ],
         max_tokens=8192,
-        # extra_body={
-        #     "decoding": "thinkdeeper",
-        #     "min_thinking_tokens" : 0,
-        #     "max_thinking_tokens" : 8000,
-        #     "max_thoughts": 100,
-        # },
         **kwargs
     )
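Note: extra_body is the OpenAI Python client's pass-through for request fields outside the standard chat-completions schema, which is how optillm-specific knobs such as "decoding" reach the proxy. A minimal sketch of the pattern this change enables; the base_url, api_key, and model name below are placeholders, not values from this repo:

from openai import OpenAI

# Placeholder client pointed at an optillm-style OpenAI-compatible proxy.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")

kwargs = {}
extra_body = {"decoding": "thinkdeeper", "min_thinking_tokens": 2048}

# Same guard as the diff: attach extra_body only when provided, so the
# default request is unchanged from the old behavior.
if extra_body:
    kwargs["extra_body"] = extra_body

response = client.chat.completions.create(
    model="example-model",  # placeholder
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    max_tokens=256,
    **kwargs,
)
print(response.choices[0].message.content)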

@@ -333,7 +331,7 @@ def get_llm_response(problem: str, model: str, analyze_logits: bool = False) ->
         logger.error(f"Error getting LLM response: {e}")
         return ""
 
-def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = False, analyze_logits: bool = False) -> List[Dict]:
+def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = False, analyze_logits: bool = False, extra_body: dict = None) -> List[Dict]:
     """
     Make n attempts to solve a problem and return all responses and predictions.
 
@@ -351,7 +349,7 @@ def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = F
     remaining_attempts = n
 
     while remaining_attempts > 0:
-        response = get_llm_response(problem, model, analyze_logits)
+        response = get_llm_response(problem, model, analyze_logits, extra_body)
 
         # If response is already formatted as attempts
         if isinstance(response, list):
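The rest of the loop is elided by the diff; a condensed sketch of the shape it implies, where the attempt-dict fields and list handling are assumptions, not code from this file:

def make_n_attempts_sketch(problem: str, model: str, n: int, extra_body: dict = None) -> list:
    # Assumed control flow: keep querying until n attempts are collected,
    # accepting either a pre-built list of attempts or a single response.
    attempts = []
    remaining_attempts = n
    while remaining_attempts > 0:
        response = get_llm_response(problem, model, False, extra_body)
        if isinstance(response, list):
            taken = response[:remaining_attempts]
            attempts.extend(taken)
            remaining_attempts -= len(taken)
        else:
            attempts.append({"response": response})  # field name is an assumption
            remaining_attempts -= 1
    return attempts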
@@ -774,7 +772,7 @@ def save_raw_response(filename: str, problem_id: int, response_data: Dict):
 
     return response_id
 
-def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_logits: bool = False):
+def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_logits: bool = False, test_time_compute: bool = False, approach_name: str = None, extra_body: dict = None):
     """Main evaluation function that handles gaps in processed indexes."""
     os.makedirs("results", exist_ok=True)
 
@@ -784,6 +782,8 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo
         suffix_parts.append("thought_analysis")
     if analyze_logits:
         suffix_parts.append("logit_analysis")
+    if approach_name:
+        suffix_parts.append(approach_name)
 
     suffix = "_" + "_".join(suffix_parts) if suffix_parts else ""
     results_file = f"results/evaluation_results_{model.replace('/', '_')}_pass_at_{n_attempts}{suffix}.json"
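With the new branch, runs for different approaches no longer overwrite each other's result files. For a hypothetical model name and approach_name="thinkdeeper_4k", the f-string above yields:

model, n_attempts = "org/example-model", 1   # hypothetical values
suffix_parts = ["thinkdeeper_4k"]            # appended by the new `if approach_name:` branch
suffix = "_" + "_".join(suffix_parts) if suffix_parts else ""
results_file = f"results/evaluation_results_{model.replace('/', '_')}_pass_at_{n_attempts}{suffix}.json"
# results/evaluation_results_org_example-model_pass_at_1_thinkdeeper_4k.json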
@@ -804,7 +804,7 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo
         correct_answer = int(item['answer'])
 
         # Make n attempts for each problem
-        attempts = make_n_attempts(problem_text, model, n_attempts, analyze_thoughts, analyze_logits)
+        attempts = make_n_attempts(problem_text, model, n_attempts, analyze_thoughts, analyze_logits, extra_body)
         is_correct, first_correct = evaluate_pass_at_n(attempts, correct_answer)
 
         result = {
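evaluate_pass_at_n itself is outside this diff; under the usual pass@n convention it reduces to something like the following sketch (the "predicted_answer" field is an assumption):

from typing import Dict, List, Optional, Tuple

def evaluate_pass_at_n_sketch(attempts: List[Dict], correct_answer: int) -> Tuple[bool, Optional[int]]:
    # pass@n: the problem counts as solved if any of the n attempts is
    # correct; also report the index of the first correct attempt.
    for i, attempt in enumerate(attempts):
        if attempt.get("predicted_answer") == correct_answer:
            return True, i
    return False, None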
@@ -826,6 +826,51 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo
     parser.add_argument("--n", type=int, default=1, help="Number of attempts per problem (for pass@n evaluation)")
     parser.add_argument("--analyze-thoughts", action="store_true", help="Analyze thinking patterns in responses")
     parser.add_argument("--analyze-logits", action="store_true", help="Analyze token probability distributions")
+    parser.add_argument("--test-time-compute", action="store_true", help="Evaluate test-time compute scaling approaches")
     args = parser.parse_args()
 
-    main(args.model, args.n, args.analyze_thoughts, args.analyze_logits)
+    if args.test_time_compute:
+        # Define test-time compute approaches with same config as eval_optillmbench.py
+        TEST_TIME_COMPUTE_APPROACHES = [
+            # Baseline
+            ("none", "Baseline without any optimization", {}),
+
+            # Sequential test-time compute using thinkdeeper with controlled thinking budgets
+            ("thinkdeeper_2k", "ThinkDeeper with 2K thinking tokens", {
+                "decoding": "thinkdeeper",
+                "min_thinking_tokens": 2048,
+                "max_thinking_tokens": 2560,  # min + 512 for flexibility
+                "max_tokens": 3072  # Total budget: max_thinking_tokens + 512
+            }),
+            ("thinkdeeper_4k", "ThinkDeeper with 4K thinking tokens", {
+                "decoding": "thinkdeeper",
+                "min_thinking_tokens": 4096,
+                "max_thinking_tokens": 4608,  # min + 512 for flexibility
+                "max_tokens": 5120  # Total budget: max_thinking_tokens + 512
+            }),
+            ("thinkdeeper_8k", "ThinkDeeper with 8K thinking tokens", {
+                "decoding": "thinkdeeper",
+                "min_thinking_tokens": 8192,
+                "max_thinking_tokens": 8704,  # min + 512 for flexibility
+                "max_tokens": 9216  # Total budget: max_thinking_tokens + 512
+            }),
+
+            # Parallel test-time compute using majority voting with different k values
+            ("majority_voting_3", "Majority Voting with k=3", {"k": 3}),
+            ("majority_voting_6", "Majority Voting with k=6", {"k": 6}),
+            ("majority_voting_9", "Majority Voting with k=9", {"k": 9}),
+        ]
+
+        # Run evaluation for each approach
+        for approach_slug, approach_name, extra_body in TEST_TIME_COMPUTE_APPROACHES:
+            print(f"\n{'=' * 80}")
+            print(f"Evaluating: {approach_name}")
+            print(f"Model: {args.model}")
+            print(f"Approach: {approach_slug}")
+            print(f"Extra body: {extra_body}")
+            print(f"{'=' * 80}\n")
+
+            main(args.model, args.n, args.analyze_thoughts, args.analyze_logits,
+                 test_time_compute=True, approach_name=approach_slug, extra_body=extra_body)
+    else:
+        main(args.model, args.n, args.analyze_thoughts, args.analyze_logits)
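All three thinkdeeper entries follow one budget formula: max_thinking_tokens = min_thinking_tokens + 512, and max_tokens = max_thinking_tokens + 512. A small helper that reproduces the literal configs above (a sketch; the script defines them inline):

def thinkdeeper_config(min_thinking_tokens: int) -> dict:
    # 512 tokens of headroom over the thinking floor, then another
    # 512 tokens on top of that for the final answer.
    max_thinking_tokens = min_thinking_tokens + 512
    return {
        "decoding": "thinkdeeper",
        "min_thinking_tokens": min_thinking_tokens,
        "max_thinking_tokens": max_thinking_tokens,
        "max_tokens": max_thinking_tokens + 512,
    }

# Matches the thinkdeeper_2k / _4k / _8k entries: 2048->3072, 4096->5120, 8192->9216.
assert thinkdeeper_config(2048)["max_tokens"] == 3072
assert thinkdeeper_config(4096)["max_tokens"] == 5120
assert thinkdeeper_config(8192)["max_tokens"] == 9216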

scripts/eval_optillmbench.py

Lines changed: 15 additions & 15 deletions
@@ -41,30 +41,30 @@
     # Baseline
     ("none", "Baseline without any optimization", {}),
 
-    # Sequential test-time compute using thinkdeeper with different minimum thinking budgets
-    ("thinkdeeper_4k", "ThinkDeeper with 4K min thinking tokens", {
+    # Sequential test-time compute using thinkdeeper with controlled thinking budgets
+    ("thinkdeeper_2k", "ThinkDeeper with 2K thinking tokens", {
         "decoding": "thinkdeeper",
-        "min_thinking_tokens": 4000,
-        "max_thinking_tokens": 20000,  # Allow up to 20K for completion
-        "max_tokens": 24000  # Total budget: 20K thinking + 4K response
+        "min_thinking_tokens": 2048,
+        "max_thinking_tokens": 2560,  # min + 512 for flexibility
+        "max_tokens": 3072  # Total budget: max_thinking_tokens + 512
     }),
-    ("thinkdeeper_8k", "ThinkDeeper with 8K min thinking tokens", {
+    ("thinkdeeper_4k", "ThinkDeeper with 4K thinking tokens", {
         "decoding": "thinkdeeper",
-        "min_thinking_tokens": 8000,
-        "max_thinking_tokens": 32000,  # Allow up to 32K for completion
-        "max_tokens": 36000  # Total budget: 32K thinking + 4K response
+        "min_thinking_tokens": 4096,
+        "max_thinking_tokens": 4608,  # min + 512 for flexibility
+        "max_tokens": 5120  # Total budget: max_thinking_tokens + 512
     }),
-    ("thinkdeeper_16k", "ThinkDeeper with 16K min thinking tokens", {
+    ("thinkdeeper_8k", "ThinkDeeper with 8K thinking tokens", {
         "decoding": "thinkdeeper",
-        "min_thinking_tokens": 16000,
-        "max_thinking_tokens": 48000,  # Allow up to 48K for completion
-        "max_tokens": 52000  # Total budget: 48K thinking + 4K response
+        "min_thinking_tokens": 8192,
+        "max_thinking_tokens": 8704,  # min + 512 for flexibility
+        "max_tokens": 9216  # Total budget: max_thinking_tokens + 512
     }),
 
     # Parallel test-time compute using majority voting with different k values
+    ("majority_voting_3", "Majority Voting with k=3", {"k": 3}),
     ("majority_voting_6", "Majority Voting with k=6", {"k": 6}),
-    ("majority_voting_12", "Majority Voting with k=12", {"k": 12}),
-    ("majority_voting_18", "Majority Voting with k=18", {"k": 18}),
+    ("majority_voting_9", "Majority Voting with k=9", {"k": 9}),
 ]
 
 def load_optillm_bench() -> datasets.Dataset:
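For the parallel rows, {"k": k} presumably tells the optillm proxy to sample k completions, with the final answer chosen by majority vote. A self-contained sketch of that aggregation step (not the repo's implementation):

from collections import Counter

def majority_vote(answers: list):
    # Most frequent answer across k sampled completions; on a tie,
    # Counter.most_common returns the answer that was seen first.
    return Counter(answers).most_common(1)[0][0]

print(majority_vote([42, 17, 42]))  # -> 42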
