Skip to content

Commit 3846a62

Browse files
committed
Add ability to pass model_dir to .ci/scripts/test_huggingface_optimum_model.py
1 parent 7750116 commit 3846a62

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

.ci/scripts/test_huggingface_optimum_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
270270
parser.add_argument("--model", type=str, required=True)
271271
parser.add_argument("--recipe", type=str, required=True)
272272
parser.add_argument("--quantize", action="store_true", help="Enable quantization")
273+
parser.add_argument("--model_dir", type=str, required=False)
273274
args = parser.parse_args()
274275

275276
model_to_model_id_and_test_function = {
@@ -294,11 +295,20 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
294295
f"Unknown model name: {args.model}. Available models: {model_to_model_id_and_test_function.keys()}"
295296
)
296297

297-
with tempfile.TemporaryDirectory() as tmp_dir:
298-
model_id, test_fn = model_to_model_id_and_test_function[args.model]
298+
model_id, test_fn = model_to_model_id_and_test_function[args.model]
299+
if args.model_dir is None:
300+
with tempfile.TemporaryDirectory() as tmp_dir:
301+
test_fn(
302+
model_id=model_id,
303+
model_dir=tmp_dir,
304+
recipe=args.recipe,
305+
quantize=args.quantize,
306+
)
307+
else:
299308
test_fn(
300309
model_id=model_id,
301-
model_dir=tmp_dir,
310+
model_dir=args.model_dir,
302311
recipe=args.recipe,
303312
quantize=args.quantize,
313+
run_only=False,
304314
)

0 commit comments

Comments
 (0)