Skip to content

Commit ce53dc1

Browse files
committed
feat: add support hard gen
1 parent 694d73c commit ce53dc1

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

bigcodebench/generate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,23 @@ def codegen(
1717
model: DecoderBase,
1818
save_path: str,
1919
subset: str,
20+
hard=False,
2021
greedy=False,
2122
strip_newlines=False,
2223
n_samples=1,
2324
id_range=None,
2425
resume=True,
2526
):
27+
extra = "Full" if not hard else "Hard"
2628
with Progress(
27-
TextColumn(f"BigCodeBench--{subset} •" + "[progress.percentage]{task.percentage:>3.0f}%"),
29+
TextColumn(f"BigCodeBench--{subset} ({extra}) •" + "[progress.percentage]{task.percentage:>3.0f}%"),
2830
BarColumn(),
2931
MofNCompleteColumn(),
3032
TextColumn("•"),
3133
TimeElapsedColumn(),
3234
) as p:
3335

34-
dataset = get_bigcodebench()
36+
dataset = get_bigcodebench(hard=hard)
3537

3638
if model.is_direct_completion() and subset == "instruct":
3739
raise Exception("Base model does not support direct completion for instruct tasks")
@@ -106,6 +108,7 @@ def main():
106108
parser = argparse.ArgumentParser()
107109
parser.add_argument("--model", required=True, type=str)
108110
parser.add_argument("--subset", required=True, type=str)
111+
parser.add_argument("--hard", action="store_true")
109112
parser.add_argument("--save_path", default=None, type=str)
110113
parser.add_argument("--bs", default=1, type=int)
111114
parser.add_argument("--n_samples", default=1, type=int)
@@ -147,16 +150,18 @@ def main():
147150
tp=args.tp,
148151
trust_remote_code=args.trust_remote_code
149152
)
150-
153+
154+
extra = "" if not args.hard else "-hard"
151155
if not args.save_path:
152-
save_path = args.model.replace("/", "--") + f"--bigcodebench-{args.subset}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
156+
save_path = args.model.replace("/", "--") + f"--bigcodebench{extra}-{args.subset}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
153157
else:
154158
save_path = args.save_path
155159

156160
codegen(
157161
model=model_runner,
158162
save_path=save_path,
159163
subset=args.subset,
164+
hard=args.hard,
160165
greedy=args.greedy,
161166
strip_newlines=args.strip_newlines,
162167
n_samples=args.n_samples,

0 commit comments

Comments
 (0)