 
 from bigcodebench.model import DecoderBase, make_model
 from bigcodebench.data import get_bigcodebench, write_jsonl
+from bigcodebench.sanitize import sanitize
 from rich.progress import (
     BarColumn,
     MofNCompleteColumn,
@@ -23,6 +24,7 @@ def codegen(
     n_samples=1,
     id_range=None,
     resume=True,
+    batch_size: int = -1,
 ):
     with Progress(
         TextColumn(f"BigCodeBench--{split.capitalize()} ({subset.capitalize()}) •" + "[progress.percentage]{task.percentage:>3.0f}%"),
@@ -41,65 +43,81 @@ def codegen(
         dirname = os.path.dirname(save_path)
         if not os.path.exists(dirname) and dirname != "":
             os.makedirs(dirname)
+
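+        # Buffers that collect tasks until a full batch is sent to the model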
+        batch_prompts = []
+        batch_task_ids = []
+        batch_nsamples = []
+        batch_entry_points = []
+
+        # Read existing data once if resuming
+        existing_data = {}
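+        # Maps task_id -> number of samples already saved in save_path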
+        if resume and os.path.exists(save_path):
+            with open(save_path, "r") as f:
+                for line in f:
+                    item = json.loads(line)
+                    existing_data[item["task_id"]] = existing_data.get(item["task_id"], 0) + 1
+
         for id_num, (task_id, task) in enumerate(p.track(dataset.items())):
             if id_range is not None:
                 low, high = id_range
-                if id_num < low or id_num >= high:
+                if id_num < low:
                     p.console.print(f"Skipping {task_id} as it is not in {id_range}")
                     continue
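+                # Tasks arrive in dataset order, so once past the range end the loop can stop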
+                if id_num >= id_range[1]:
+                    break
 
             p_name = task_id.replace("/", "_")
 
-            # read the existing file if save_path exists
-            if os.path.exists(save_path):
-                with open(save_path, "r") as f:
-                    existing_data = f.read().splitlines()
-            log = f"Codegen: {p_name} @ {model}"
-            n_existing = 0
-            if resume:
-                if os.path.exists(save_path):
-                    n_existing = len([1 for line in existing_data if json.loads(line)["task_id"] == task_id])
-                else:
-                    n_existing = 0
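+            # Per-task counts now come from the single pass over save_path done before the loop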
+            n_existing = existing_data.get(task_id, 0)
+            nsamples = n_samples - n_existing
+
+            try:
+                prompt = task[f"{split}_prompt"]
+            except:
+                raise Exception(f"Invalid split {split} for bigcodebench-{subset}")
+            if strip_newlines:
+                prompt = prompt.strip("\n")
+
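+            # Only queue tasks that still need completions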
+            if nsamples > 0:
+                batch_prompts.append(prompt)
+                batch_task_ids.append(task_id)
+                batch_nsamples.append(nsamples)
+                batch_entry_points.append(task["entry_point"])
+
+            log = f"Codegen: {p_name} @ {model}"
             if n_existing > 0:
                 log += f" (resuming from {n_existing})"
-
-            nsamples = n_samples - n_existing
-            p.console.print(log)
-
-            sidx = n_samples - nsamples
-            while sidx < n_samples:
-                try:
-                    prompt = task[f"{split}_prompt"]
-                except:
-                    raise Exception(f"Invalid split {split}")
-                if strip_newlines:
-                    prompt = prompt.strip("\n")
+            p.console.print(log)
+
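+            # Flush when the batch is full, at the final task, or at the end of the requested id_range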
+            if (batch_size and len(batch_prompts) == batch_size) or id_num == len(dataset) - 1 or (id_range and id_num == id_range[1] - 1):
+                if not batch_prompts and id_num == len(dataset) - 1:
+                    break
                 outputs = model.codegen(
-                    prompt,
+                    batch_prompts,
                     do_sample=not greedy,
-                    num_samples=n_samples - sidx,
+                    num_samples=max(batch_nsamples),
                 )
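+                # max(batch_nsamples) covers the neediest task; per-task slicing below trims the extras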
                 assert outputs, "No outputs from model!"
-                if model.is_direct_completion():
-                    samples = [
-                        dict(
-                            task_id=task_id,
-                            solution=task["complete_prompt"]+completion
-                        )
-                        for task_id, completion in zip([task_id]*len(outputs), outputs)
-                    ]
-                else:
-                    samples = [
-                        dict(
-                            task_id=task_id,
-                            solution=completion,
-                        )
-                        for task_id, completion in zip([task_id]*len(outputs), outputs)
-                    ]
+
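+                # Pair each task with its slice of the batched outputs and sanitize each solution around its entry point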
+                samples = []
+                for task_id, content, entry_point, nsamples, task_outputs in zip(batch_task_ids, batch_prompts, batch_entry_points, batch_nsamples, outputs):
+                    if model.is_direct_completion():
+                        samples.extend([
+                            dict(task_id=task_id, solution=sanitize(content+completion, entry_point))
+                            for completion in task_outputs[:nsamples]
+                        ])
+                    else:
+                        samples.extend([
+                            dict(task_id=task_id, solution=sanitize(completion, entry_point))
+                            for completion in task_outputs[:nsamples]
+                        ])
                 print(f"Generated {len(samples)} samples")
                 write_jsonl(save_path, samples, append=True)
-                sidx += len(outputs)
+
+                # Clear batches
+                batch_prompts = []
+                batch_task_ids = []
+                batch_nsamples = []
+                batch_entry_points = []
 
 
 def main():
@@ -113,6 +131,7 @@ def main():
     parser.add_argument("--temperature", default=0.0, type=float)
     parser.add_argument("--greedy", action="store_true")
     parser.add_argument("--strip_newlines", action="store_true")
+    parser.add_argument("--direct_completion", action="store_true")
     parser.add_argument("--resume", action="store_true")
     parser.add_argument("--id_range", nargs=2, type=int)
     parser.add_argument("--backend", default="vllm", type=str, choices=["vllm", "hf", "openai", "mistral", "anthropic", "google"])
@@ -126,7 +145,6 @@ def main():
 
     if args.greedy or (args.temperature == 0 and args.n_samples == 1):
         args.temperature = 0
-        args.bs = 1
         args.n_samples = 1
         args.greedy = True
         print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")
@@ -140,18 +158,20 @@ def main():
     model_runner = make_model(
         model=args.model,
         backend=args.backend,
-        batch_size=args.bs,
+        subset=args.subset,
+        split=args.split,
         temperature=args.temperature,
         base_url=args.base_url,
         tp=args.tp,
         trust_remote_code=args.trust_remote_code,
+        direct_completion=args.direct_completion,
         tokenizer_name=args.tokenizer_name,
         tokenizer_legacy=args.tokenizer_legacy
     )
 
     extra = "-" + args.subset if args.subset != "full" else ""
     if not args.save_path:
-        save_path = args.model.replace("/", "--") + f"--bigcodebench{extra}-{args.split}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
+        save_path = args.model.replace("/", "--") + f"--bigcodebench{extra}-{args.split}--{args.backend}-{args.temperature}-{args.n_samples}-sanitized_calibrated.jsonl"
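+        # The suffix records that saved solutions were sanitized and calibrated at generation time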
     else:
         save_path = args.save_path
 
@@ -164,7 +184,8 @@ def main():
         strip_newlines=args.strip_newlines,
         n_samples=args.n_samples,
         resume=args.resume,
-        id_range=args.id_range
+        id_range=args.id_range,
+        batch_size=args.bs
     )
 
 