@@ -17,21 +17,23 @@ def codegen(
1717 model : DecoderBase ,
1818 save_path : str ,
1919 subset : str ,
20+ hard = False ,
2021 greedy = False ,
2122 strip_newlines = False ,
2223 n_samples = 1 ,
2324 id_range = None ,
2425 resume = True ,
2526):
27+ extra = "Full" if not hard else "Hard"
2628 with Progress (
27- TextColumn (f"BigCodeBench--{ subset } •" + "[progress.percentage]{task.percentage:>3.0f}%" ),
29+ TextColumn (f"BigCodeBench--{ subset } ( { extra } ) •" + "[progress.percentage]{task.percentage:>3.0f}%" ),
2830 BarColumn (),
2931 MofNCompleteColumn (),
3032 TextColumn ("•" ),
3133 TimeElapsedColumn (),
3234 ) as p :
3335
34- dataset = get_bigcodebench ()
36+ dataset = get_bigcodebench (hard = hard )
3537
3638 if model .is_direct_completion () and subset == "instruct" :
3739 raise Exception ("Base model does not support direct completion for instruct tasks" )
@@ -106,6 +108,7 @@ def main():
106108 parser = argparse .ArgumentParser ()
107109 parser .add_argument ("--model" , required = True , type = str )
108110 parser .add_argument ("--subset" , required = True , type = str )
111+ parser .add_argument ("--hard" , action = "store_true" )
109112 parser .add_argument ("--save_path" , default = None , type = str )
110113 parser .add_argument ("--bs" , default = 1 , type = int )
111114 parser .add_argument ("--n_samples" , default = 1 , type = int )
@@ -147,16 +150,18 @@ def main():
147150 tp = args .tp ,
148151 trust_remote_code = args .trust_remote_code
149152 )
150-
153+
154+ extra = "" if not args .hard else "-hard"
151155 if not args .save_path :
152- save_path = args .model .replace ("/" , "--" ) + f"--bigcodebench-{ args .subset } --{ args .backend } -{ args .temperature } -{ args .n_samples } .jsonl"
156+ save_path = args .model .replace ("/" , "--" ) + f"--bigcodebench{ extra } -{ args .subset } --{ args .backend } -{ args .temperature } -{ args .n_samples } .jsonl"
153157 else :
154158 save_path = args .save_path
155159
156160 codegen (
157161 model = model_runner ,
158162 save_path = save_path ,
159163 subset = args .subset ,
164+ hard = args .hard ,
160165 greedy = args .greedy ,
161166 strip_newlines = args .strip_newlines ,
162167 n_samples = args .n_samples ,
0 commit comments