Skip to content

Commit 9a06f9c

Browse files
committed
Add --revision to optimum-cli export openvino
1 parent 69311c0 commit 9a06f9c

File tree

3 files changed

+47
-10
lines changed

3 files changed

+47
-10
lines changed

optimum/commands/export/openvino.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,11 @@ def parse_args_openvino(parser: "ArgumentParser"):
263263
"reduces quantization error. Valid only when activations quantization is enabled."
264264
),
265265
)
266+
optional_group.add_argument(
267+
"--revision",
268+
type=str,
269+
help=("Revision of the model to load."),
270+
)
266271
optional_group.add_argument(
267272
"--model-kwargs",
268273
type=json.loads,
@@ -332,7 +337,7 @@ def run(self):
332337
from ...intel.openvino.configuration import _DEFAULT_4BIT_WQ_CONFIG, OVConfig, get_default_quantization_config
333338

334339
if self.args.library is None:
335-
# TODO: add revision, subfolder and token to args
340+
# TODO: add subfolder and token to args
336341
library_name = _infer_library_from_model_name_or_path(
337342
model_name_or_path=self.args.model, cache_dir=self.args.cache_dir
338343
)
@@ -427,6 +432,7 @@ def run(self):
427432
self.args.model,
428433
cache_dir=self.args.cache_dir,
429434
trust_remote_code=self.args.trust_remote_code,
435+
revision=self.args.revision,
430436
)
431437
if getattr(config, "model_type", "").replace("_", "-") in MULTI_MODAL_TEXT_GENERATION_MODELS:
432438
task = "image-text-to-text"
@@ -473,7 +479,9 @@ def run(self):
473479
else:
474480
raise NotImplementedError(f"Quantization isn't supported for class {class_name}.")
475481

476-
model = model_cls.from_pretrained(self.args.model, export=True, quantization_config=quantization_config)
482+
model = model_cls.from_pretrained(
483+
self.args.model, export=True, quantization_config=quantization_config, revision=self.args.revision
484+
)
477485
model.save_pretrained(self.args.output)
478486
if not self.args.disable_convert_tokenizer:
479487
maybe_convert_tokenizers(library_name, self.args.output, model, task=task)
@@ -529,6 +537,7 @@ def run(self):
529537
trust_remote_code=self.args.trust_remote_code,
530538
variant=self.args.variant,
531539
cache_dir=self.args.cache_dir,
540+
revision=self.args.revision,
532541
)
533542
model.save_pretrained(self.args.output)
534543

@@ -551,6 +560,7 @@ def run(self):
551560
convert_tokenizer=not self.args.disable_convert_tokenizer,
552561
library_name=library_name,
553562
variant=self.args.variant,
563+
revision=self.args.revision,
554564
model_kwargs=self.args.model_kwargs,
555565
# **input_shapes,
556566
)

optimum/exporters/openvino/__main__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,9 @@ class StoreAttr(object):
383383

384384
try:
385385
if library_name == "open_clip":
386-
model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
386+
model = _OpenClipForZeroShotImageClassification.from_pretrained(
387+
model_name_or_path, cache_dir=cache_dir, revision=revision
388+
)
387389
else:
388390
model = TasksManager.get_model_from_task(
389391
task,
@@ -407,7 +409,7 @@ class StoreAttr(object):
407409
if pad_token_id is not None:
408410
model.config.pad_token_id = pad_token_id
409411
else:
410-
tok = AutoTokenizer.from_pretrained(model_name_or_path)
412+
tok = AutoTokenizer.from_pretrained(model_name_or_path, revision=revision)
411413
pad_token_id = getattr(tok, "pad_token_id", None)
412414
if pad_token_id is None:
413415
raise ValueError(

tests/openvino/test_exporters_cli.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,11 @@ class OVCLIExportTestCase(unittest.TestCase):
201201
"whisper",
202202
"int8",
203203
"--dataset librispeech --num-samples 1 --smooth-quant-alpha 0.9 --trust-remote-code",
204-
{"encoder": 10, "decoder": 12, "decoder_with_past": 11}
205-
if is_transformers_version("<=", "4.36.0")
206-
else {"encoder": 8, "decoder": 12, "decoder_with_past": 25},
204+
(
205+
{"encoder": 10, "decoder": 12, "decoder_with_past": 11}
206+
if is_transformers_version("<=", "4.36.0")
207+
else {"encoder": 8, "decoder": 12, "decoder_with_past": 25}
208+
),
207209
(
208210
{"encoder": {"int8": 8}, "decoder": {"int8": 11}, "decoder_with_past": {"int8": 9}}
209211
if is_transformers_version("<=", "4.36.0")
@@ -215,9 +217,11 @@ class OVCLIExportTestCase(unittest.TestCase):
215217
"whisper",
216218
"f8e4m3",
217219
"--dataset librispeech --num-samples 1 --smooth-quant-alpha 0.9 --trust-remote-code",
218-
{"encoder": 10, "decoder": 12, "decoder_with_past": 11}
219-
if is_transformers_version("<=", "4.36.0")
220-
else {"encoder": 8, "decoder": 12, "decoder_with_past": 25},
220+
(
221+
{"encoder": 10, "decoder": 12, "decoder_with_past": 11}
222+
if is_transformers_version("<=", "4.36.0")
223+
else {"encoder": 8, "decoder": 12, "decoder_with_past": 25}
224+
),
221225
(
222226
{"encoder": {"f8e4m3": 8}, "decoder": {"f8e4m3": 11}, "decoder_with_past": {"f8e4m3": 9}}
223227
if is_transformers_version("<=", "4.36.0")
@@ -1142,3 +1146,24 @@ def test_export_openvino_with_custom_variant(self):
11421146
model = eval(_HEAD_TO_AUTOMODELS["stable-diffusion"]).from_pretrained(tmpdir, compile=False)
11431147
for component in ["text_encoder", "tokenizer", "unet", "vae_encoder", "vae_decoder"]:
11441148
self.assertIsNotNone(getattr(model, component))
1149+
1150+
def test_export_openvino_with_revision(self):
1151+
with TemporaryDirectory() as tmpdir:
1152+
subprocess.run(
1153+
f"optimum-cli export openvino --model hf-internal-testing/tiny-random-MistralForCausalLM --revision 7158fab {tmpdir}",
1154+
shell=True,
1155+
check=True,
1156+
)
1157+
model = eval(_HEAD_TO_AUTOMODELS["text-generation"]).from_pretrained(tmpdir, compile=False)
1158+
1159+
with TemporaryDirectory() as tmpdir:
1160+
result = subprocess.run(
1161+
f"optimum-cli export openvino --model hf-internal-testing/tiny-random-MistralForCausalLM --revision 7158fac {tmpdir}",
1162+
shell=True,
1163+
check=False,
1164+
text=True,
1165+
stdout=subprocess.PIPE,
1166+
stderr=subprocess.PIPE,
1167+
)
1168+
self.assertNotEqual(result.returncode, 0)
1169+
self.assertIn("not a valid git identifier", result.stderr)

0 commit comments

Comments
 (0)