@@ -107,8 +107,8 @@ def codegen(
107107def main ():
108108 parser = argparse .ArgumentParser ()
109109 parser .add_argument ("--model" , required = True , type = str )
110- parser .add_argument ("--split" , required = True , type = str )
111- parser .add_argument ("--subset" , default = "full" , type = str )
110+ parser .add_argument ("--split" , required = True , type = str , choices = [ "complete" , "instruct" ] )
111+ parser .add_argument ("--subset" , default = "full" , type = str , choices = [ "full" , "hard" ] )
112112 parser .add_argument ("--save_path" , default = None , type = str )
113113 parser .add_argument ("--bs" , default = 1 , type = int )
114114 parser .add_argument ("--n_samples" , default = 1 , type = int )
@@ -117,17 +117,12 @@ def main():
117117 parser .add_argument ("--strip_newlines" , action = "store_true" )
118118 parser .add_argument ("--resume" , action = "store_true" )
119119 parser .add_argument ("--id_range" , nargs = 2 , type = int )
120- parser .add_argument ("--backend" , default = "vllm" , type = str )
120+ parser .add_argument ("--backend" , default = "vllm" , type = str , choices = [ "vllm" , "hf" , "openai" , "mistral" , "anthropic" , "google" ] )
121121 parser .add_argument ("--base_url" , default = None , type = str )
122122 parser .add_argument ("--tp" , default = 1 , type = int )
123123 parser .add_argument ("--trust_remote_code" , action = "store_true" )
124124 args = parser .parse_args ()
125125
126-
127- assert args .split in ["complete" , "instruct" ], f"Invalid split { args .split } "
128- assert args .subset in ["full" , "hard" ], f"Invalid subset { args .subset } "
129- assert args .backend in ["vllm" , "hf" , "openai" , "mistral" , "anthropic" , "google" ]
130-
131126 if args .greedy and (args .temperature != 0 or args .bs != 1 or args .n_samples != 1 )\
132127 or (args .temperature == 0 and args .n_samples == 1 ):
133128 args .temperature = 0
0 commit comments