@@ -262,14 +262,20 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
262262
263263 assert torch .allclose (
264264 eager_output .logits , et_output , atol = 1e-02 , rtol = 1e-02
265- ), "CoreML output does not match eager"
265+ ), "Model output does not match eager"
266266
267267
268268if __name__ == "__main__" :
269269 parser = argparse .ArgumentParser ()
270270 parser .add_argument ("--model" , type = str , required = True )
271271 parser .add_argument ("--recipe" , type = str , required = True )
272272 parser .add_argument ("--quantize" , action = "store_true" , help = "Enable quantization" )
273+ parser .add_argument (
274+ "--model_dir" ,
275+ type = str ,
276+ required = False ,
277+ help = "When provided, write the pte file to this directory. Otherwise, a temporary directory is created for the test." ,
278+ )
273279 args = parser .parse_args ()
274280
275281 model_to_model_id_and_test_function = {
@@ -294,11 +300,11 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
294300 f"Unknown model name: { args .model } . Available models: { model_to_model_id_and_test_function .keys ()} "
295301 )
296302
303+ model_id , test_fn = model_to_model_id_and_test_function [args .model ]
297304 with tempfile .TemporaryDirectory () as tmp_dir :
298- model_id , test_fn = model_to_model_id_and_test_function [args .model ]
299305 test_fn (
300306 model_id = model_id ,
301- model_dir = tmp_dir ,
307+ model_dir = tmp_dir if args . model_dir is None else args . model_dir ,
302308 recipe = args .recipe ,
303309 quantize = args .quantize ,
304310 )
0 commit comments