Skip to content

Commit be2b6b5

Browse files
authored
Add ability to pass model_dir to .ci/scripts/test_huggingface_optimum (#13116)
Adds ability to specify model_dir to optimum test script. This is convinient if you want the pte file for local debugging. If no model_dir is specified, a temp directory is created.
1 parent 414fc32 commit be2b6b5

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

.ci/scripts/test_huggingface_optimum_model.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,20 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
262262

263263
assert torch.allclose(
264264
eager_output.logits, et_output, atol=1e-02, rtol=1e-02
265-
), "CoreML output does not match eager"
265+
), "Model output does not match eager"
266266

267267

268268
if __name__ == "__main__":
269269
parser = argparse.ArgumentParser()
270270
parser.add_argument("--model", type=str, required=True)
271271
parser.add_argument("--recipe", type=str, required=True)
272272
parser.add_argument("--quantize", action="store_true", help="Enable quantization")
273+
parser.add_argument(
274+
"--model_dir",
275+
type=str,
276+
required=False,
277+
help="When provided, write the pte file to this directory. Otherwise, a temporary directory is created for the test.",
278+
)
273279
args = parser.parse_args()
274280

275281
model_to_model_id_and_test_function = {
@@ -294,11 +300,11 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
294300
f"Unknown model name: {args.model}. Available models: {model_to_model_id_and_test_function.keys()}"
295301
)
296302

303+
model_id, test_fn = model_to_model_id_and_test_function[args.model]
297304
with tempfile.TemporaryDirectory() as tmp_dir:
298-
model_id, test_fn = model_to_model_id_and_test_function[args.model]
299305
test_fn(
300306
model_id=model_id,
301-
model_dir=tmp_dir,
307+
model_dir=tmp_dir if args.model_dir is None else args.model_dir,
302308
recipe=args.recipe,
303309
quantize=args.quantize,
304310
)

0 commit comments

Comments
 (0)