Skip to content

Commit 88b017c

Browse files
Commit message: allow passing ov_config to from_pretrained
1 parent 8f9cf89 commit 88b017c

File tree

6 files changed

+53
-37
lines changed

6 files changed

+53
-37
lines changed

optimum/intel/openvino/modeling_base.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -619,11 +619,13 @@ def _from_transformers(
619619
)
620620
compile_only = False
621621

622-
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
623-
if load_in_8bit is None and not quantization_config:
624-
ov_config = None
625-
else:
626-
ov_config = OVConfig(dtype="fp32")
622+
ov_config = kwargs.get("ov_config")
623+
if ov_config is None:
624+
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
625+
if load_in_8bit is None and not quantization_config:
626+
ov_config = None
627+
else:
628+
ov_config = OVConfig(dtype="fp32")
627629

628630
variant = kwargs.pop("variant", None)
629631

optimum/intel/openvino/modeling_decoder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,13 @@ def _from_transformers(
306306
if use_cache:
307307
task = task + "-with-past"
308308

309-
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
310-
if load_in_8bit is None and not quantization_config:
311-
ov_export_config = None
312-
else:
313-
ov_export_config = OVConfig(dtype="auto")
309+
ov_export_config = kwargs.get("ov_config")
310+
if ov_export_config is None:
311+
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
312+
if load_in_8bit is None and not quantization_config:
313+
ov_export_config = None
314+
else:
315+
ov_export_config = OVConfig(dtype="auto")
314316

315317
stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
316318

optimum/intel/openvino/modeling_diffusion.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -605,12 +605,15 @@ def _from_transformers(
605605
)
606606
compile_only = False
607607

608-
# If load_in_8bit and quantization_config not specified then ov_config is set
609-
# to None and will be set by default in convert depending on the model size
610-
if load_in_8bit is None and not quantization_config:
611-
ov_config = None
612-
else:
613-
ov_config = OVConfig(dtype="auto")
608+
ov_config = kwargs.get("ov_config")
609+
610+
if ov_config is None:
611+
# If load_in_8bit and quantization_config not specified then ov_config is set
612+
# to None and will be set by default in convert depending on the model size
613+
if load_in_8bit is None and not quantization_config:
614+
ov_config = None
615+
else:
616+
ov_config = OVConfig(dtype="auto")
614617

615618
torch_dtype = kwargs.pop("torch_dtype", None)
616619

optimum/intel/openvino/modeling_open_clip.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,13 @@ def _from_transformers(
243243
# would end-up removing the directory containing the underlying OpenVINO model
244244
cls._model_save_dir_tempdirectory_instance = save_dir
245245

246-
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
247-
if load_in_8bit is None and not quantization_config:
248-
ov_config = None
249-
else:
250-
ov_config = OVConfig(dtype="fp32")
246+
ov_config = kwargs.get("ov_config")
247+
if ov_config is None:
248+
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
249+
if load_in_8bit is None and not quantization_config:
250+
ov_config = None
251+
else:
252+
ov_config = OVConfig(dtype="fp32")
251253

252254
def fn_get_submodels(model):
253255
return {"model_text": model.text}
@@ -368,11 +370,14 @@ def _from_transformers(
368370
# would end-up removing the directory containing the underlying OpenVINO model
369371
cls._model_save_dir_tempdirectory_instance = save_dir
370372

371-
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
372-
if load_in_8bit is None and not quantization_config:
373-
ov_config = None
374-
else:
375-
ov_config = OVConfig(dtype="fp32")
373+
ov_config = kwargs.get("ov_config")
374+
375+
if ov_config is None:
376+
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
377+
if load_in_8bit is None and not quantization_config:
378+
ov_config = None
379+
else:
380+
ov_config = OVConfig(dtype="fp32")
376381

377382
def fn_get_submodels(model):
378383
return {"model_vision": model.visual}

optimum/intel/openvino/modeling_seq2seq.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -602,11 +602,13 @@ def _from_transformers(
602602
"Please provide openvino model obtained using optimum-cli or saved on disk using `save_pretrained`"
603603
)
604604
compile_only = False
605-
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
606-
if load_in_8bit is None and not quantization_config:
607-
ov_config = None
608-
else:
609-
ov_config = OVConfig(dtype="fp32")
605+
ov_config = kwargs.get("ov_config")
606+
if ov_config is None:
607+
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
608+
if load_in_8bit is None and not quantization_config:
609+
ov_config = None
610+
else:
611+
ov_config = OVConfig(dtype="fp32")
610612
stateful = kwargs.get("stateful", True)
611613
variant = kwargs.pop("variant", None)
612614

optimum/intel/openvino/modeling_visual_language.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -660,12 +660,14 @@ def _from_transformers(
660660
if task is None:
661661
task = cls.export_feature
662662

663-
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
664-
if load_in_8bit is None and not quantization_config:
665-
ov_config = None
666-
else:
667-
# Export in fp32 if compression won't be applied later
668-
ov_config = OVConfig(dtype="fp32" if load_in_8bit is False else "auto")
663+
ov_config = kwargs.get("ov_config")
664+
if ov_config is None:
665+
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
666+
if load_in_8bit is None and not quantization_config:
667+
ov_config = None
668+
else:
669+
# Export in fp32 if compression won't be applied later
670+
ov_config = OVConfig(dtype="fp32" if load_in_8bit is False else "auto")
669671

670672
stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
671673
variant = kwargs.pop("variant", None)

0 commit comments

Comments (0)