@@ -281,33 +281,34 @@ def stucking_checker():
281281 if not os .path .isfile (result_path ):
282282 with open (result_path , "w" ) as f :
283283 json .dump (results , f , indent = 2 )
284-
285- pass_at_k_path = result_path .replace ("_eval_results.json" , "_pass_at_k.json" )
286- pass_at_k ["model" ] = os .path .basename (flags .samples ).split ("--bigcodebench-" )[0 ]
287- pass_at_k ["calibrated" ] = "sanitized-calibrated" in flags .samples
288- pass_at_k ["subset" ] = flags .subset
289-
290- def save_pass_at_k ():
291- with open (pass_at_k_path , "w" ) as f :
292- json .dump (pass_at_k , f , indent = 2 )
293-
294- if os .path .isfile (pass_at_k_path ):
295- saved_pass_at_k = json .load (open (pass_at_k_path , "r" ))
296- # compare saved_pass_at_k with pass_at_k
297- for k in saved_pass_at_k .keys ():
298- if pass_at_k [k ] != saved_pass_at_k [k ]:
299- cprint (f"Warning: { k } is different from the saved one" , "yellow" )
284+
285+ if flags .save_pass_rate :
286+ pass_at_k_path = result_path .replace ("_eval_results.json" , "_pass_at_k.json" )
287+ pass_at_k ["model" ] = os .path .basename (flags .samples ).split ("--bigcodebench-" )[0 ]
288+ pass_at_k ["calibrated" ] = "sanitized-calibrated" in flags .samples
289+ pass_at_k ["subset" ] = flags .subset
290+
291+ def save_pass_at_k ():
292+ with open (pass_at_k_path , "w" ) as f :
293+ json .dump (pass_at_k , f , indent = 2 )
294+
295+ if os .path .isfile (pass_at_k_path ):
296+ saved_pass_at_k = json .load (open (pass_at_k_path , "r" ))
297+ # compare saved_pass_at_k with pass_at_k
298+ for k in saved_pass_at_k .keys ():
299+ if pass_at_k [k ] != saved_pass_at_k [k ]:
300+ cprint (f"Warning: { k } is different from the saved one" , "yellow" )
301+
302+ # ask user whether to save the pass@k
303+ decision = ""
304+ while decision .lower () not in ["y" , "n" ]:
305+ print (f"Save pass@k to { pass_at_k_path } ? [Y/N]" )
306+ decision = input ()
307+ if decision .lower () == "y" :
308+ save_pass_at_k ()
300309
301- # ask user whether to save the pass@k
302- decision = ""
303- while decision .lower () not in ["y" , "n" ]:
304- print (f"Save pass@k to { pass_at_k_path } ? [Y/N]" )
305- decision = input ()
306- if decision .lower () == "y" :
310+ else :
307311 save_pass_at_k ()
308-
309- else :
310- save_pass_at_k ()
311312
312313
313314def main ():
@@ -317,6 +318,7 @@ def main():
317318 )
318319 parser .add_argument ("--subset" , default = "full" , type = str , choices = ["full" , "hard" ])
319320 parser .add_argument ("--samples" , required = True , type = str )
321+ parser .add_argument ("--save_pass_rate" , action = "store_true" )
320322 parser .add_argument ("--parallel" , default = None , type = int )
321323 parser .add_argument ("--min-time-limit" , default = 1 , type = float )
322324 parser .add_argument ("--max-as-limit" , default = 128 * 1024 , type = int )
0 commit comments