

 def calculate_accuracy(responses):
-    """Calculate accuracy metrics for LongBench v2."""
+    """Calculate accuracy metrics for LongBench v2.
+
+    Logic is aligned with longbench_stx.print_metrics, but returns a dict
+    and additionally computes by_domain statistics.
+    """
     total = len(responses)
     if total == 0:
         return {}

-    # Overall accuracy
-    correct = sum(1 for r in responses if r.get("judge", False))
-    overall_acc = round(100 * correct / total, 1)
-
-    # By difficulty
-    easy_items = [r for r in responses if r.get("difficulty") == "easy"]
-    hard_items = [r for r in responses if r.get("difficulty") == "hard"]
-    easy_acc = (
-        round(100 * sum(1 for r in easy_items if r.get("judge", False)) / len(easy_items), 1)
-        if easy_items
-        else 0.0
-    )
-    hard_acc = (
-        round(100 * sum(1 for r in hard_items if r.get("judge", False)) / len(hard_items), 1)
-        if hard_items
-        else 0.0
-    )
-
-    # By length
-    short_items = [r for r in responses if r.get("length") == "short"]
-    medium_items = [r for r in responses if r.get("length") == "medium"]
-    long_items = [r for r in responses if r.get("length") == "long"]
-
-    short_acc = (
-        round(100 * sum(1 for r in short_items if r.get("judge", False)) / len(short_items), 1)
-        if short_items
-        else 0.0
-    )
-    medium_acc = (
-        round(100 * sum(1 for r in medium_items if r.get("judge", False)) / len(medium_items), 1)
-        if medium_items
-        else 0.0
-    )
-    long_acc = (
-        round(100 * sum(1 for r in long_items if r.get("judge", False)) / len(long_items), 1)
-        if long_items
-        else 0.0
-    )
-
-    # By domain
+    # Counters (aligned with longbench_stx.print_metrics)
+    easy = hard = short = medium = long = 0
+    easy_acc = hard_acc = short_acc = medium_acc = long_acc = 0
+    total_prompt_tokens = 0
+
+    for pred in responses:
+        acc = int(pred.get("judge", False))
+        diff = pred.get("difficulty", "easy")
+        length = pred.get("length", "short")
+
+        pt = pred.get("prompt_tokens")
+        if isinstance(pt, int | float):
+            total_prompt_tokens += int(pt)
+
+        if diff == "easy":
+            easy += 1
+            easy_acc += acc
+        else:
+            hard += 1
+            hard_acc += acc
+
+        if length == "short":
+            short += 1
+            short_acc += acc
+        elif length == "medium":
+            medium += 1
+            medium_acc += acc
+        else:
+            long += 1
+            long_acc += acc
+
+    o_acc = round(100 * (easy_acc + hard_acc) / total, 2)
+    e_acc = round(100 * easy_acc / easy, 2) if easy > 0 else 0.0
+    h_acc = round(100 * hard_acc / hard, 2) if hard > 0 else 0.0
+    s_acc = round(100 * short_acc / short, 2) if short > 0 else 0.0
+    m_acc = round(100 * medium_acc / medium, 2) if medium > 0 else 0.0
+    l_acc = round(100 * long_acc / long, 2) if long > 0 else 0.0
+
+    # Additional by-domain stats (extra vs. stx)
     domain_stats = {}
-    for response in responses:
-        domain = response.get("domain", "Unknown")
+    for r in responses:
+        domain = r.get("domain", "Unknown")
         if domain not in domain_stats:
             domain_stats[domain] = {"total": 0, "correct": 0}
         domain_stats[domain]["total"] += 1
-        if response.get("judge", False):
+        if r.get("judge", False):
             domain_stats[domain]["correct"] += 1

     domain_acc = {
-        domain: round(100 * stats["correct"] / stats["total"], 1)
+        domain: round(100 * stats["correct"] / stats["total"], 2)
         for domain, stats in domain_stats.items()
     }

     return {
-        "overall": overall_acc,
-        "easy": easy_acc,
-        "hard": hard_acc,
-        "short": short_acc,
-        "medium": medium_acc,
-        "long": long_acc,
+        "overall": o_acc,
+        "easy": e_acc,
+        "hard": h_acc,
+        "short": s_acc,
+        "medium": m_acc,
+        "long": l_acc,
         "by_domain": domain_acc,
         "total_samples": total,
-        "correct_samples": correct,
+        "correct_samples": easy_acc + hard_acc,
+        "total_prompt_tokens": total_prompt_tokens,
+        "avg_prompt_tokens": round(total_prompt_tokens / total, 2) if total > 0 else 0.0,
     }


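For reference, a minimal usage sketch of the rewritten counter-based accumulation. The records below are hypothetical and only illustrate the fields the function reads (judge, difficulty, length, prompt_tokens, domain); the sketch assumes calculate_accuracy from the hunk above is in scope.

# Hypothetical sample records; the field values are illustrative only.
sample_responses = [
    {"judge": True, "difficulty": "easy", "length": "short", "prompt_tokens": 1200, "domain": "Single-Doc QA"},
    {"judge": False, "difficulty": "hard", "length": "medium", "prompt_tokens": 5400, "domain": "Multi-Doc QA"},
    {"judge": True, "difficulty": "hard", "length": "long", "prompt_tokens": 9100, "domain": "Multi-Doc QA"},
]

metrics = calculate_accuracy(sample_responses)
# overall = 2/3 -> 66.67, easy = 1/1 -> 100.0, hard = 1/2 -> 50.0
# avg_prompt_tokens = (1200 + 5400 + 9100) / 3 -> 5233.33
print(metrics["overall"], metrics["easy"], metrics["hard"], metrics["avg_prompt_tokens"])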
@@ -92,11 +97,36 @@ def main(frame, version="default"):
     with open(responses_path, encoding="utf-8") as f:
         responses = json.load(f)

-    # Only keep entries with non-empty context (search_context) to align with response generation
-    filtered = [r for r in responses if str(r.get("search_context", "")).strip() != ""]
-
-    # Calculate metrics
-    metrics = calculate_accuracy(filtered)
+    # Only keep entries that actually have search results:
+    # - For new pipeline: non-empty memories_used list
+    # - For older runs: non-empty search_context string
+    def _has_search_results(r: dict) -> bool:
+        mems = r.get("memories_used")
+        if isinstance(mems, list) and any(str(m).strip() for m in mems):
+            return True
+        ctx = str(r.get("search_context", "")).strip()
+        return ctx != ""
+
+    filtered = [r for r in responses if _has_search_results(r)]
+
+    # Calculate metrics (handle case where no samples have search results)
+    if not filtered:
+        print("⚠️ No responses with valid search results were found. Metrics will be zeroed.")
+        metrics = {
+            "overall": 0.0,
+            "easy": 0.0,
+            "hard": 0.0,
+            "short": 0.0,
+            "medium": 0.0,
+            "long": 0.0,
+            "by_domain": {},
+            "total_samples": 0,
+            "correct_samples": 0,
+            "total_prompt_tokens": 0,
+            "avg_prompt_tokens": 0.0,
+        }
+    else:
+        metrics = calculate_accuracy(filtered)

     # Save metrics
     output_path = f"results/long_bench_v2/{frame}-{version}/{frame}_longbench_v2_metrics.json"
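As a quick check of the new filtering rule, the sketch below restates the nested _has_search_results helper standalone (in the diff it is defined inside main) and runs it on hypothetical records covering both record shapes.

# Standalone restatement of the nested helper, for illustration only.
def _has_search_results(r: dict) -> bool:
    mems = r.get("memories_used")
    if isinstance(mems, list) and any(str(m).strip() for m in mems):
        return True
    return str(r.get("search_context", "")).strip() != ""

print(_has_search_results({"memories_used": ["retrieved fact A"]}))           # True  (new pipeline)
print(_has_search_results({"search_context": "retrieved passage text"}))      # True  (older runs)
print(_has_search_results({"memories_used": ["  "], "search_context": ""}))   # False (nothing usable)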
@@ -112,12 +142,13 @@ def main(frame, version="default"):
     # Print summary table
     print("\n📊 Summary of Results:")
     print("-" * 80)
-    print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.1f}%")
-    print(f"{'Easy':<30s}: {metrics['easy']:.1f}%")
-    print(f"{'Hard':<30s}: {metrics['hard']:.1f}%")
-    print(f"{'Short':<30s}: {metrics['short']:.1f}%")
-    print(f"{'Medium':<30s}: {metrics['medium']:.1f}%")
-    print(f"{'Long':<30s}: {metrics['long']:.1f}%")
+    print(f"{'Overall Accuracy':<30s}: {metrics['overall']:.2f}%")
+    print(f"{'Easy':<30s}: {metrics['easy']:.2f}%")
+    print(f"{'Hard':<30s}: {metrics['hard']:.2f}%")
+    print(f"{'Short':<30s}: {metrics['short']:.2f}%")
+    print(f"{'Medium':<30s}: {metrics['medium']:.2f}%")
+    print(f"{'Long':<30s}: {metrics['long']:.2f}%")
+    print(f"{'Avg Prompt Tokens':<30s}: {metrics.get('avg_prompt_tokens', 0.0):.2f}")
     print("\nBy Domain:")
     for domain, acc in metrics["by_domain"].items():
         print(f"  {domain:<28s}: {acc:.1f}%")