Skip to content

Commit f202b74

Browse files
committed
fix: change dataset to subset
1 parent a13bf9e commit f202b74

File tree

2 files changed

+5
-14
lines changed

2 files changed

+5
-14
lines changed

bigcodebench/eval/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ def is_floats(x) -> bool:
106106

107107

108108
def unsafe_execute(
109-
dataset: str,
110109
entry_point: str,
111110
code: str,
112111
test_code: str,
@@ -168,7 +167,6 @@ def unsafe_execute(
168167

169168

170169
def untrusted_check(
171-
dataset: str,
172170
code: str,
173171
test_code: str,
174172
entry_point: str,
@@ -185,7 +183,6 @@ def untrusted_check(
185183
p = multiprocessing.Process(
186184
target=unsafe_execute,
187185
args=(
188-
dataset,
189186
entry_point,
190187
code,
191188
test_code,
@@ -217,7 +214,6 @@ def untrusted_check(
217214

218215

219216
def evaluate_files(
220-
dataset: str,
221217
files: List[str],
222218
inputs: List,
223219
entry_point: str,
@@ -230,7 +226,6 @@ def evaluate_files(
230226
for file in files:
231227
code = open(file, "r").read()
232228
stat, det = untrusted_check(
233-
dataset,
234229
code,
235230
inputs,
236231
entry_point,

bigcodebench/evaluate.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def get_groundtruth(problems, hashcode, check_gt_only):
6262
return expected_time
6363

6464
def check_correctness(
65-
dataset: str,
6665
completion_id: int,
6766
problem: Dict[str, Any],
6867
solution: str,
@@ -77,7 +76,6 @@ def check_correctness(
7776
"solution": solution,
7877
}
7978
ret["base"] = untrusted_check(
80-
dataset,
8179
solution,
8280
problem["test"],
8381
problem["entry_point"],
@@ -119,10 +117,9 @@ def evaluate(flags):
119117

120118
results = compatible_eval_result(results)
121119
else:
122-
if flags.dataset == "bigcodebench":
123-
problems = get_bigcodebench()
124-
dataset_hash = get_bigcodebench_hash()
125-
expected_time = get_groundtruth(problems, dataset_hash, flags.check_gt_only)
120+
problems = get_bigcodebench()
121+
dataset_hash = get_bigcodebench_hash()
122+
expected_time = get_groundtruth(problems, dataset_hash, flags.check_gt_only)
126123

127124
if flags.check_gt_only:
128125
return
@@ -157,7 +154,6 @@ def evaluate(flags):
157154
solution = problems[task_id]["prompt_wo_doc"] + "\n pass\n" + solution
158155
remainings.add(sample["_identifier"])
159156
args = (
160-
flags.dataset,
161157
completion_id[task_id],
162158
problems[task_id],
163159
solution,
@@ -219,7 +215,7 @@ def stucking_checker():
219215
for k in [1, 5, 10, 25, 100]
220216
if total.min() >= k
221217
}
222-
cprint(f"{flags.dataset}", "green")
218+
cprint(f"BigCodeBench-{flags.subset}", "green")
223219
for k, v in pass_at_k.items():
224220
cprint(f"{k}:\t{v:.3f}", "green")
225221

@@ -246,7 +242,7 @@ def stucking_checker():
246242
def main():
247243
parser = argparse.ArgumentParser()
248244
parser.add_argument(
249-
"--dataset", required=True, type=str, choices=["bigcodebench"]
245+
"--subset", required=True, type=str, choices=["c2c", "nl2c"]
250246
)
251247
parser.add_argument("--samples", required=True, type=str)
252248
parser.add_argument("--parallel", default=None, type=int)

0 commit comments

Comments (0)