diff --git a/.ci/scripts/test_huggingface_optimum_model.py b/.ci/scripts/test_huggingface_optimum_model.py index 8a0b244c549..6a31eabb0c8 100644 --- a/.ci/scripts/test_huggingface_optimum_model.py +++ b/.ci/scripts/test_huggingface_optimum_model.py @@ -262,7 +262,7 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False): assert torch.allclose( eager_output.logits, et_output, atol=1e-02, rtol=1e-02 - ), "CoreML output does not match eager" + ), "Model output does not match eager" if __name__ == "__main__": @@ -270,6 +270,12 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False): parser.add_argument("--model", type=str, required=True) parser.add_argument("--recipe", type=str, required=True) parser.add_argument("--quantize", action="store_true", help="Enable quantization") + parser.add_argument( + "--model_dir", + type=str, + required=False, + help="When provided, write the pte file to this directory. Otherwise, a temporary directory is created for the test.", + ) args = parser.parse_args() model_to_model_id_and_test_function = { @@ -294,11 +300,11 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False): f"Unknown model name: {args.model}. Available models: {model_to_model_id_and_test_function.keys()}" ) + model_id, test_fn = model_to_model_id_and_test_function[args.model] with tempfile.TemporaryDirectory() as tmp_dir: - model_id, test_fn = model_to_model_id_and_test_function[args.model] test_fn( model_id=model_id, - model_dir=tmp_dir, + model_dir=tmp_dir if args.model_dir is None else args.model_dir, recipe=args.recipe, quantize=args.quantize, )