Skip to content

Commit be21a9a

Browse files
committed
feat: add support for hard eval
1 parent ce53dc1 commit be21a9a

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

bigcodebench/evaluate.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,15 @@ def evaluate(flags):
124124
assert flags.samples.endswith(".jsonl")
125125
result_path = flags.samples.replace(".jsonl", "_eval_results.json")
126126

127-
problems = get_bigcodebench()
128-
dataset_hash = get_bigcodebench_hash()
127+
problems = get_bigcodebench(hard=flags.hard)
128+
dataset_hash = get_bigcodebench_hash(hard=flags.hard)
129129

130130
if not flags.no_gt:
131131
expected_time = get_groundtruth(n_workers, problems, dataset_hash, flags.check_gt_only, flags.max_as_limit, flags.max_data_limit, flags.max_stack_limit)
132132
else:
133133
expected_time = {task_id: None for task_id in problems}
134134

135-
gt_pass_rate = np.mean([1 if v is not None else 0 for v in expected_time.values()])
135+
gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
136136

137137
if os.path.isfile(result_path):
138138
print(f"Load from previous results from {result_path}")
@@ -229,10 +229,12 @@ def stucking_checker():
229229
)
230230

231231
# Calculate pass@k.
232-
total = np.array([len(r) for r in results["eval"].values()])
232+
total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
233233
base_correct = []
234234

235-
for res in results["eval"].values():
235+
for key, res in results["eval"].items():
236+
if key not in problems:
237+
continue
236238
bc = sum([r["status"] == PASS for r in res])
237239
base_correct.append(bc)
238240

@@ -245,8 +247,9 @@ def stucking_checker():
245247
}
246248

247249
mode = "-calibrated" if "sanitized-calibrated" in flags.samples else ""
250+
extra = "Full" if not flags.hard else "Hard"
248251
flags.subset = flags.subset[0].upper() + flags.subset[1:]
249-
cprint(f"BigCodeBench-{flags.subset}{mode}", "green")
252+
cprint(f"BigCodeBench-{flags.subset}{mode} ({extra})", "green")
250253

251254
if flags.no_gt:
252255
cprint(f"Groundtruth is not checked", "yellow")
@@ -284,6 +287,7 @@ def main():
284287
parser.add_argument(
285288
"--subset", required=True, type=str, choices=["complete", "instruct"]
286289
)
290+
parser.add_argument("--hard", action="store_true")
287291
parser.add_argument("--samples", required=True, type=str)
288292
parser.add_argument("--parallel", default=None, type=int)
289293
parser.add_argument("--min-time-limit", default=1, type=float)

0 commit comments

Comments
 (0)