Skip to content

Commit 057a84a

Browse files
committed
feat: make pass_rate_save optional
1 parent c65100d commit 057a84a

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

bigcodebench/evaluate.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -281,33 +281,34 @@ def stucking_checker():
281281
if not os.path.isfile(result_path):
282282
with open(result_path, "w") as f:
283283
json.dump(results, f, indent=2)
284-
285-
pass_at_k_path = result_path.replace("_eval_results.json", "_pass_at_k.json")
286-
pass_at_k["model"] = os.path.basename(flags.samples).split("--bigcodebench-")[0]
287-
pass_at_k["calibrated"] = "sanitized-calibrated" in flags.samples
288-
pass_at_k["subset"] = flags.subset
289-
290-
def save_pass_at_k():
291-
with open(pass_at_k_path, "w") as f:
292-
json.dump(pass_at_k, f, indent=2)
293-
294-
if os.path.isfile(pass_at_k_path):
295-
saved_pass_at_k = json.load(open(pass_at_k_path, "r"))
296-
# compare saved_pass_at_k with pass_at_k
297-
for k in saved_pass_at_k.keys():
298-
if pass_at_k[k] != saved_pass_at_k[k]:
299-
cprint(f"Warning: {k} is different from the saved one", "yellow")
284+
285+
if flags.save_pass_rate:
286+
pass_at_k_path = result_path.replace("_eval_results.json", "_pass_at_k.json")
287+
pass_at_k["model"] = os.path.basename(flags.samples).split("--bigcodebench-")[0]
288+
pass_at_k["calibrated"] = "sanitized-calibrated" in flags.samples
289+
pass_at_k["subset"] = flags.subset
290+
291+
def save_pass_at_k():
292+
with open(pass_at_k_path, "w") as f:
293+
json.dump(pass_at_k, f, indent=2)
294+
295+
if os.path.isfile(pass_at_k_path):
296+
saved_pass_at_k = json.load(open(pass_at_k_path, "r"))
297+
# compare saved_pass_at_k with pass_at_k
298+
for k in saved_pass_at_k.keys():
299+
if pass_at_k[k] != saved_pass_at_k[k]:
300+
cprint(f"Warning: {k} is different from the saved one", "yellow")
301+
302+
# ask user whether to save the pass@k
303+
decision = ""
304+
while decision.lower() not in ["y", "n"]:
305+
print(f"Save pass@k to {pass_at_k_path}? [Y/N]")
306+
decision = input()
307+
if decision.lower() == "y":
308+
save_pass_at_k()
300309

301-
# ask user whether to save the pass@k
302-
decision = ""
303-
while decision.lower() not in ["y", "n"]:
304-
print(f"Save pass@k to {pass_at_k_path}? [Y/N]")
305-
decision = input()
306-
if decision.lower() == "y":
310+
else:
307311
save_pass_at_k()
308-
309-
else:
310-
save_pass_at_k()
311312

312313

313314
def main():
@@ -317,6 +318,7 @@ def main():
317318
)
318319
parser.add_argument("--subset", default="full", type=str, choices=["full", "hard"])
319320
parser.add_argument("--samples", required=True, type=str)
321+
parser.add_argument("--save_pass_rate", action="store_true")
320322
parser.add_argument("--parallel", default=None, type=int)
321323
parser.add_argument("--min-time-limit", default=1, type=float)
322324
parser.add_argument("--max-as-limit", default=128*1024, type=int)

0 commit comments

Comments
 (0)