1111from concurrent .futures import ThreadPoolExecutor , as_completed
1212
1313from datasets import load_dataset
14- import optimize # the file Weco mutates
14+ import optimize # the file Weco mutates
1515
# ---------------------------------------------------------------------
# Configuration
TOTAL_SAMPLES = 20  # how many problems to load
NUM_WORKERS = 20  # concurrent LLM calls
LOG_EVERY = 5  # print progress after this many
# ---------------------------------------------------------------------

# Load the evaluation set once at import time. The split slice is built from
# TOTAL_SAMPLES so only the first TOTAL_SAMPLES rows of "train" are requested,
# and the download is cached under ./.cache.
print(f"[setup] loading {TOTAL_SAMPLES} problems from AIME 2024 …")
DATA = load_dataset(
    "Maxwell-Jia/AIME_2024",
    split=f"train[:{TOTAL_SAMPLES}]",
    cache_dir=".cache",
)
25+
2926
def extract_number(text: str) -> str:
    """Return the first standalone 1-3 digit number found in *text*.

    The pattern requires word boundaries on both sides, so a digit run of
    four or more characters (e.g. ``"1234"``) is not matched, and neither
    are digits glued to letters or underscores.

    Args:
        text: Free-form text to scan.

    Returns:
        The matched digits as a string, exactly as they appear in *text*
        (leading zeros are kept), or ``""`` when nothing matches.
    """
    m = re.search(r"\b(\d{1,3})\b", text)
    return m.group(1) if m else ""
3330
31+
def score_one(row) -> bool:
    """Return whether the solver's answer for one dataset row is correct.

    Calls ``optimize.solve`` on ``row["Problem"]``, pulls a number out of the
    reply with ``extract_number``, and compares it to ``row["Answer"]``.

    Args:
        row: A mapping with at least the keys ``"Problem"`` and ``"Answer"``.

    Returns:
        True when the extracted text equals ``str(row["Answer"])``.
    """
    guess = extract_number(optimize.solve(row["Problem"]))
    # NOTE(review): this is a string comparison, so a reply of "073" does not
    # equal an answer of 73 — confirm that is the intended scoring rule.
    return guess == str(row["Answer"])
3735
36+
3837def accuracy () -> float :
3938 correct = 0
4039 start = time .time ()
@@ -46,13 +45,11 @@ def accuracy() -> float:
4645
4746 if idx % LOG_EVERY == 0 or idx == TOTAL_SAMPLES :
4847 elapsed = time .time () - start
49- print (
50- f"[progress] { idx } /{ TOTAL_SAMPLES } completed, "
51- f"elapsed { elapsed :.1f} s"
52- )
48+ print (f"[progress] { idx } /{ TOTAL_SAMPLES } completed, elapsed { elapsed :.1f} s" )
5349
5450 return correct / TOTAL_SAMPLES
5551
52+
if __name__ == "__main__":
    # Script entry point: run the evaluation and report the score.
    acc = accuracy()
    print(f"accuracy: {acc:.4f}")  # Weco parses this line
0 commit comments