@@ -124,15 +124,15 @@ def evaluate(flags):
124124 assert flags .samples .endswith (".jsonl" )
125125 result_path = flags .samples .replace (".jsonl" , "_eval_results.json" )
126126
127- problems = get_bigcodebench ()
128- dataset_hash = get_bigcodebench_hash ()
127+ problems = get_bigcodebench (hard = flags . hard )
128+ dataset_hash = get_bigcodebench_hash (hard = flags . hard )
129129
130130 if not flags .no_gt :
131131 expected_time = get_groundtruth (n_workers , problems , dataset_hash , flags .check_gt_only , flags .max_as_limit , flags .max_data_limit , flags .max_stack_limit )
132132 else :
133133 expected_time = {task_id : None for task_id in problems }
134134
135- gt_pass_rate = np .mean ([1 if v is not None else 0 for v in expected_time .values () ])
135+ gt_pass_rate = np .mean ([1 if v is not None else 0 for k , v in expected_time .items () if k in problems ])
136136
137137 if os .path .isfile (result_path ):
138138 print (f"Load from previous results from { result_path } " )
@@ -229,10 +229,12 @@ def stucking_checker():
229229 )
230230
231231 # Calculate pass@k.
232- total = np .array ([len (r ) for r in results ["eval" ].values () ])
232+ total = np .array ([len (r ) for k , r in results ["eval" ].items () if k in problems ])
233233 base_correct = []
234234
235- for res in results ["eval" ].values ():
235+ for key , res in results ["eval" ].items ():
236+ if key not in problems :
237+ continue
236238 bc = sum ([r ["status" ] == PASS for r in res ])
237239 base_correct .append (bc )
238240
@@ -245,8 +247,9 @@ def stucking_checker():
245247 }
246248
247249 mode = "-calibrated" if "sanitized-calibrated" in flags .samples else ""
250+ extra = "Full" if not flags .hard else "Hard"
248251 flags .subset = flags .subset [0 ].upper () + flags .subset [1 :]
249- cprint (f"BigCodeBench-{ flags .subset } { mode } " , "green" )
252+ cprint (f"BigCodeBench-{ flags .subset } { mode } ( { extra } ) " , "green" )
250253
251254 if flags .no_gt :
252255 cprint (f"Groundtruth is not checked" , "yellow" )
@@ -284,6 +287,7 @@ def main():
284287 parser .add_argument (
285288 "--subset" , required = True , type = str , choices = ["complete" , "instruct" ]
286289 )
290+ parser .add_argument ("--hard" , action = "store_true" )
287291 parser .add_argument ("--samples" , required = True , type = str )
288292 parser .add_argument ("--parallel" , default = None , type = int )
289293 parser .add_argument ("--min-time-limit" , default = 1 , type = float )
0 commit comments