
Commit 8f83db1

feat: refactor generate pipeline
1 parent 129e263 commit 8f83db1

File tree

3 files changed: +236 −177 lines changed

bigcodebench/generate.py

Lines changed: 68 additions & 47 deletions
@@ -4,6 +4,7 @@
 
 from bigcodebench.model import DecoderBase, make_model
 from bigcodebench.data import get_bigcodebench, write_jsonl
+from bigcodebench.sanitize import sanitize
 from rich.progress import (
     BarColumn,
     MofNCompleteColumn,
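
The new sanitize import is central to this commit: every completion is now passed through sanitize(code, entry_point) before being written out (see the third hunk below). The real implementation lives in bigcodebench/sanitize.py; purely as a mental model, a toy version that keeps only the imports and the target function might look like the following (sanitize_sketch is hypothetical, not the project's API):

    import ast

    def sanitize_sketch(code: str, entry_point: str) -> str:
        # Illustrative stand-in for bigcodebench.sanitize.sanitize:
        # keep imports plus the target function, drop surrounding chatter.
        try:
            tree = ast.parse(code)
        except SyntaxError:
            return code  # the real sanitizer is far more forgiving than this
        kept = [
            node for node in tree.body
            if isinstance(node, (ast.Import, ast.ImportFrom))
            or (isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                and node.name == entry_point)
        ]
        return "\n".join(ast.unparse(node) for node in kept)  # ast.unparse needs Python 3.9+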
@@ -23,6 +24,7 @@ def codegen(
     n_samples=1,
     id_range=None,
     resume=True,
+    batch_size: int=-1,
 ):
     with Progress(
         TextColumn(f"BigCodeBench--{split.capitalize()} ({subset.capitalize()}) •" + "[progress.percentage]{task.percentage:>3.0f}%"),
@@ -41,65 +43,81 @@ def codegen(
         dirname = os.path.dirname(save_path)
         if not os.path.exists(dirname) and dirname != "":
             os.makedirs(dirname)
+
+        batch_prompts = []
+        batch_task_ids = []
+        batch_nsamples = []
+        batch_entry_points = []
+
+        # Read existing data once if resuming
+        existing_data = {}
+        if resume and os.path.exists(save_path):
+            with open(save_path, "r") as f:
+                for line in f:
+                    item = json.loads(line)
+                    existing_data[item["task_id"]] = existing_data.get(item["task_id"], 0) + 1
+
         for id_num, (task_id, task) in enumerate(p.track(dataset.items())):
             if id_range is not None:
                 low, high = id_range
-                if id_num < low or id_num >= high:
+                if id_num < low:
                     p.console.print(f"Skipping {task_id} as it is not in {id_range}")
                     continue
+                if id_num > id_range[1]:
+                    break
 
             p_name = task_id.replace("/", "_")
 
-            # read the existing file if save_path exists
-            if os.path.exists(save_path):
-                with open(save_path, "r") as f:
-                    existing_data = f.read().splitlines()
-            log = f"Codegen: {p_name} @ {model}"
-            n_existing = 0
-            if resume:
-                if os.path.exists(save_path):
-                    n_existing = len([1 for line in existing_data if json.loads(line)["task_id"] == task_id])
-                else:
-                    n_existing = 0
+            n_existing = existing_data.get(task_id, 0)
+            nsamples = n_samples - n_existing
+
+            try:
+                prompt = task[f"{split}_prompt"]
+            except:
+                raise Exception(f"Invalid split {split} for bigcodebench-{subset}")
+            if strip_newlines:
+                prompt = prompt.strip("\n")
+
+            if nsamples > 0:
+                batch_prompts.append(prompt)
+                batch_task_ids.append(task_id)
+                batch_nsamples.append(nsamples)
+                batch_entry_points.append(task["entry_point"])
+
+            log = f"Codegen: {p_name} @ {model}"
             if n_existing > 0:
                 log += f" (resuming from {n_existing})"
-
-            nsamples = n_samples - n_existing
-            p.console.print(log)
-
-            sidx = n_samples - nsamples
-            while sidx < n_samples:
-                try:
-                    prompt = task[f"{split}_prompt"]
-                except:
-                    raise Exception(f"Invalid split {split}")
-                if strip_newlines:
-                    prompt = prompt.strip("\n")
+            p.console.print(log)
+
+            if (batch_size and len(batch_prompts) == batch_size) or id_num == len(dataset) - 1 or (id_range and id_num == id_range[1] - 1):
+                if not batch_prompts and id_num == len(dataset) - 1:
+                    break
                 outputs = model.codegen(
-                    prompt,
+                    batch_prompts,
                     do_sample=not greedy,
-                    num_samples=n_samples - sidx,
+                    num_samples=max(batch_nsamples),
                 )
                 assert outputs, "No outputs from model!"
-                if model.is_direct_completion():
-                    samples = [
-                        dict(
-                            task_id=task_id,
-                            solution=task["complete_prompt"]+completion
-                        )
-                        for task_id, completion in zip([task_id]*len(outputs), outputs)
-                    ]
-                else:
-                    samples = [
-                        dict(
-                            task_id=task_id,
-                            solution=completion,
-                        )
-                        for task_id, completion in zip([task_id]*len(outputs), outputs)
-                    ]
+
+                samples = []
+                for task_id, content, entry_point, nsamples, task_outputs in zip(batch_task_ids, batch_prompts, batch_entry_points, batch_nsamples, outputs):
+                    if model.is_direct_completion():
+                        samples.extend([
+                            dict(task_id=task_id, solution=sanitize(content+completion, entry_point))
+                            for completion in task_outputs[:nsamples]
+                        ])
+                    else:
+                        samples.extend([
+                            dict(task_id=task_id, solution=sanitize(completion, entry_point))
+                            for completion in task_outputs[:nsamples]
+                        ])
                 print(f"Generated {len(samples)} samples")
                 write_jsonl(save_path, samples, append=True)
-                sidx += len(outputs)
+
+                # Clear batches
+                batch_prompts = []
+                batch_task_ids = []
+                batch_nsamples = []
 
 
 def main():
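
This hunk is the core of the refactor: instead of looping sample-by-sample per task, prompts accumulate into a batch that is flushed through a single model.codegen call once the batch is full, the dataset ends, or the id_range upper bound is reached. Note that with the default batch_size=-1 the length check never fires, so everything remaining goes out as one batch at the end. A stripped-down sketch of the same flush pattern, with a stubbed generate callback standing in for model.codegen (batched_codegen and generate are illustrative, not project API):

    from typing import Callable, Dict, List

    def batched_codegen(prompts: List[str], needs: List[int], batch_size: int,
                        generate: Callable[[List[str], int], List[List[str]]]) -> Dict[int, List[str]]:
        # generate(batch, k) must return k completions per prompt, one list per prompt.
        results: Dict[int, List[str]] = {}
        batch, batch_needs, batch_ids = [], [], []
        for i, prompt in enumerate(prompts):
            batch.append(prompt)
            batch_needs.append(needs[i])
            batch_ids.append(i)
            # Flush on a full batch or at the end of the dataset,
            # mirroring the condition in the refactored codegen().
            if len(batch) == batch_size or i == len(prompts) - 1:
                outputs = generate(batch, max(batch_needs))
                for idx, need, outs in zip(batch_ids, batch_needs, outputs):
                    results[idx] = outs[:need]  # keep only what this task still needs
                batch, batch_needs, batch_ids = [], [], []
        return results

Since num_samples is max(batch_nsamples) for the whole batch, tasks that need fewer samples over-generate and the surplus is dropped by the [:nsamples] slice, trading a little compute for one uniform batched call.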
@@ -113,6 +131,7 @@ def main():
     parser.add_argument("--temperature", default=0.0, type=float)
     parser.add_argument("--greedy", action="store_true")
     parser.add_argument("--strip_newlines", action="store_true")
+    parser.add_argument("--direct_completion", action="store_true")
     parser.add_argument("--resume", action="store_true")
     parser.add_argument("--id_range", nargs=2, type=int)
     parser.add_argument("--backend", default="vllm", type=str, choices=["vllm", "hf", "openai", "mistral", "anthropic", "google"])
@@ -126,7 +145,6 @@
 
     if args.greedy or (args.temperature == 0 and args.n_samples == 1):
         args.temperature = 0
-        args.bs = 1
         args.n_samples = 1
         args.greedy = True
         print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")
@@ -140,18 +158,20 @@ def main():
     model_runner = make_model(
         model=args.model,
         backend=args.backend,
-        batch_size=args.bs,
+        subset=args.subset,
+        split=args.split,
         temperature=args.temperature,
         base_url=args.base_url,
         tp=args.tp,
         trust_remote_code=args.trust_remote_code,
+        direct_completion=args.direct_completion,
         tokenizer_name=args.tokenizer_name,
         tokenizer_legacy=args.tokenizer_legacy
     )
 
     extra = "-" + args.subset if args.subset != "full" else ""
     if not args.save_path:
-        save_path = args.model.replace("/", "--") + f"--bigcodebench{extra}-{args.split}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
+        save_path = args.model.replace("/", "--") + f"--bigcodebench{extra}-{args.split}--{args.backend}-{args.temperature}-{args.n_samples}-sanitized_calibrated.jsonl"
     else:
         save_path = args.save_path
 
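The default save path now carries a -sanitized_calibrated suffix, flagging that solutions were sanitized at generation time. As an illustration (model name hypothetical), running with --model bigcode/starcoder2-15b --backend vllm --split complete --subset full under the greedy defaults (temperature forced to 0, n_samples to 1) would produce:

    bigcode--starcoder2-15b--bigcodebench-complete--vllm-0-1-sanitized_calibrated.jsonl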
@@ -164,7 +184,8 @@ def main():
         strip_newlines=args.strip_newlines,
         n_samples=args.n_samples,
         resume=args.resume,
-        id_range=args.id_range
+        id_range=args.id_range,
+        batch_size=args.bs
     )
 
 
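Putting the new flags together, one plausible invocation of the refactored pipeline (model name illustrative; --bs is assumed from the args.bs reference above, the other flags are defined in main()):

    python -m bigcodebench.generate \
        --model bigcode/starcoder2-15b \
        --split complete \
        --subset full \
        --backend vllm \
        --temperature 0.8 \
        --n_samples 5 \
        --bs 8 \
        --resume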