Commit f1517e3

sbalandi and charlaix authored
Add openvino support for OpenCLIP (#857)
* Add support for OpenCLIP
* fix comments
* rename class
* remove patching

Co-authored-by: Ella Charlaix <[email protected]>
1 parent cff9902 commit f1517e3

File tree: 18 files changed, +1400 −72 lines

18 files changed

+1400
-72
lines changed
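
With this change, an OpenCLIP checkpoint can be exported to OpenVINO IR either via the CLI (the new --library open_clip choice) or programmatically. Below is a minimal sketch of the programmatic path; the hub id and output directory are placeholders, and it assumes main_export is importable from optimum.exporters.openvino as elsewhere in optimum-intel.

# Sketch only: the hub id and output directory are placeholders, not from this commit.
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="laion/CLIP-ViT-B-32-laion2B-s34B-b79K",  # placeholder OpenCLIP checkpoint
    output="ov_open_clip",
    task="auto",               # resolves to "zero-shot-image-classification" for open_clip
    library_name="open_clip",  # or None to let the exporter infer the library
)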

optimum/commands/export/openvino.py

Lines changed: 4 additions & 3 deletions
@@ -22,6 +22,7 @@
 
 from ...exporters import TasksManager
 from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available
+from ...intel.utils.modeling_utils import _infer_library_from_model_name_or_path
 from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 from ..base import BaseOptimumCLICommand, CommandInfo
 
@@ -77,7 +78,7 @@ def parse_args_openvino(parser: "ArgumentParser"):
     optional_group.add_argument(
         "--library",
         type=str,
-        choices=["transformers", "diffusers", "timm", "sentence_transformers"],
+        choices=["transformers", "diffusers", "timm", "sentence_transformers", "open_clip"],
         default=None,
         help="The library used to load the model before export. If not provided, will attempt to infer the local checkpoint's library",
     )
@@ -234,7 +235,7 @@ def run(self):
 
         if self.args.library is None:
             # TODO: add revision, subfolder and token to args
-            library_name = TasksManager._infer_library_from_model_name_or_path(
+            library_name = _infer_library_from_model_name_or_path(
                 model_name_or_path=self.args.model, cache_dir=self.args.cache_dir
             )
             if library_name == "sentence_transformers":
@@ -280,7 +281,7 @@ def run(self):
 
         quantization_config = ov_config.quantization_config if ov_config else None
         quantize_with_dataset = quantization_config and getattr(quantization_config, "dataset", None) is not None
-        task = infer_task(self.args.task, self.args.model)
+        task = infer_task(self.args.task, self.args.model, library_name=library_name)
 
         if library_name == "diffusers" and quantize_with_dataset:
             if not is_diffusers_available():
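
The net effect on the CLI flow: the library is now inferred first, via the relocated _infer_library_from_model_name_or_path helper, and threaded into task inference. A condensed sketch of that ordering, using a hypothetical wrapper name and assuming args exposes model, cache_dir, task, and library attributes:

from optimum.exporters.openvino.__main__ import infer_task
from optimum.intel.utils.modeling_utils import _infer_library_from_model_name_or_path

def resolve_task_and_library(args):  # hypothetical helper mirroring run() above
    # Resolve the library first, unless --library was passed explicitly...
    library_name = args.library or _infer_library_from_model_name_or_path(
        model_name_or_path=args.model, cache_dir=args.cache_dir
    )
    # ...so that task inference can special-case open_clip checkpoints.
    task = infer_task(args.task, args.model, library_name=library_name)
    return task, library_name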

optimum/exporters/openvino/__main__.py

Lines changed: 58 additions & 39 deletions
@@ -36,6 +36,10 @@
     is_openvino_version,
     is_transformers_version,
 )
+from optimum.intel.utils.modeling_utils import (
+    _infer_library_from_model_name_or_path,
+    _OpenClipForZeroShotImageClassification,
+)
 from optimum.utils.save_utils import maybe_load_preprocessors
 
 from .utils import _MAX_UNCOMPRESSED_SIZE, clear_class_registry
@@ -59,25 +63,29 @@ def infer_task(
     revision: Optional[str] = None,
     cache_dir: str = HUGGINGFACE_HUB_CACHE,
     token: Optional[Union[bool, str]] = None,
+    library_name: Optional[str] = None,
 ):
     task = TasksManager.map_from_synonym(task)
     if task == "auto":
-        try:
-            task = TasksManager._infer_task_from_model_name_or_path(
-                model_name_or_path=model_name_or_path,
-                subfolder=subfolder,
-                revision=revision,
-                cache_dir=cache_dir,
-                token=token,
-            )
-        except KeyError as e:
-            raise KeyError(
-                f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
-            )
-        except RequestsConnectionError as e:
-            raise RequestsConnectionError(
-                f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
-            )
+        if library_name == "open_clip":
+            task = "zero-shot-image-classification"
+        else:
+            try:
+                task = TasksManager._infer_task_from_model_name_or_path(
+                    model_name_or_path=model_name_or_path,
+                    subfolder=subfolder,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    token=token,
+                )
+            except KeyError as e:
+                raise KeyError(
+                    f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
+                )
+            except RequestsConnectionError as e:
+                raise RequestsConnectionError(
+                    f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
+                )
     return task
 
 
@@ -182,16 +190,13 @@ def main_export(
             raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
         token = use_auth_token
 
-    original_task = task
-    task = infer_task(
-        task, model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
-    )
-    framework = TasksManager.determine_framework(
-        model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
-    )
+    if framework is None:
+        framework = TasksManager.determine_framework(
+            model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
+        )
 
     if library_name is None:
-        library_name = TasksManager._infer_library_from_model_name_or_path(
+        library_name = _infer_library_from_model_name_or_path(
             model_name_or_path=model_name_or_path,
             subfolder=subfolder,
             revision=revision,
@@ -205,6 +210,17 @@
         )
         library_name = "transformers"
 
+    original_task = task
+    task = infer_task(
+        task,
+        model_name_or_path,
+        subfolder=subfolder,
+        revision=revision,
+        cache_dir=cache_dir,
+        token=token,
+        library_name=library_name,
+    )
+
     do_gptq_patching = False
     custom_architecture = False
     patch_16bit = False
@@ -305,21 +321,24 @@ class StoreAttr(object):
 
         GPTQQuantizer.post_init_model = post_init_model
 
-    model = TasksManager.get_model_from_task(
-        task,
-        model_name_or_path,
-        subfolder=subfolder,
-        revision=revision,
-        cache_dir=cache_dir,
-        token=token,
-        local_files_only=local_files_only,
-        force_download=force_download,
-        trust_remote_code=trust_remote_code,
-        framework=framework,
-        device=device,
-        library_name=library_name,
-        **loading_kwargs,
-    )
+    if library_name == "open_clip":
+        model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+    else:
+        model = TasksManager.get_model_from_task(
+            task,
+            model_name_or_path,
+            subfolder=subfolder,
+            revision=revision,
+            cache_dir=cache_dir,
+            token=token,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            trust_remote_code=trust_remote_code,
+            framework=framework,
+            device=device,
+            library_name=library_name,
+            **loading_kwargs,
+        )
 
     needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
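
Because open_clip checkpoints carry no pipeline-tag metadata on the Hub, the task cannot be looked up remotely; the new branch in infer_task hard-codes it instead, so --task auto now resolves offline for this library. A small sketch of that behavior (the model id is a placeholder and is never contacted on this code path):

from optimum.exporters.openvino.__main__ import infer_task

# The open_clip branch short-circuits before any Hub lookup.
task = infer_task("auto", "placeholder/open-clip-model", library_name="open_clip")
assert task == "zero-shot-image-classification"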

optimum/exporters/openvino/convert.py

Lines changed: 27 additions & 18 deletions
@@ -33,6 +33,7 @@
 from optimum.exporters.utils import _get_submodels_and_export_configs
 from optimum.intel.utils.import_utils import (
     _nncf_version,
+    _open_clip_version,
     _optimum_intel_version,
     _optimum_version,
     _timm_version,
@@ -44,11 +45,13 @@
 from optimum.utils.save_utils import maybe_save_preprocessors
 
 from ...intel.utils.import_utils import is_nncf_available
+from ...intel.utils.modeling_utils import _infer_library_from_model_or_model_class
 from .model_patcher import patch_model_with_bettertransformer
 from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful
 from .utils import (
     OV_XML_FILE_NAME,
     _get_input_info,
+    _get_open_clip_submodels_fn_and_export_configs,
     clear_class_registry,
     remove_none_from_dummy_inputs,
 )
@@ -183,11 +186,7 @@ def export_tensorflow(
     input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
     ov_model = convert_model(str(onnx_path))
 
-    if model.__class__.__module__.startswith("optimum"):
-        # for wrapped models
-        library_name = TasksManager._infer_library_from_model_or_model_class(model=model.model)
-    else:
-        library_name = TasksManager._infer_library_from_model_or_model_class(model=model)
+    library_name = _infer_library_from_model_or_model_class(model=model)
 
     _save_model(
         ov_model,
@@ -248,11 +247,7 @@ def export_pytorch_via_onnx(
     torch.onnx.export = orig_torch_onnx_export
     ov_model = convert_model(str(onnx_output))
 
-    if model.__class__.__module__.startswith("optimum"):
-        # for wrapped models
-        library_name = TasksManager._infer_library_from_model_or_model_class(model=model.model)
-    else:
-        library_name = TasksManager._infer_library_from_model_or_model_class(model=model)
+    library_name = _infer_library_from_model_or_model_class(model=model)
 
     _save_model(
         ov_model,
@@ -422,11 +417,7 @@ def ts_patched_forward(*args, **kwargs):
         if stateful:
             patch_stateful(model.config, ov_model)
 
-        if model.__module__.startswith("optimum"):
-            # for wrapped models like timm in optimum.intel.openvino.modeling_timm
-            library_name = TasksManager._infer_library_from_model_or_model_class(model=model.model)
-        else:
-            library_name = TasksManager._infer_library_from_model_or_model_class(model=model)
+        library_name = _infer_library_from_model_or_model_class(model=model)
 
         _save_model(
             ov_model,
@@ -535,8 +526,9 @@ def export_from_model(
             f"Compression of the weights to {ov_config.quantization_config} requires nncf, please install it with `pip install nncf`"
         )
 
-    library_name = TasksManager._infer_library_from_model_or_model_class(model=model)
-    TasksManager.standardize_model_attributes(model)
+    library_name = _infer_library_from_model_or_model_class(model)
+    if library_name != "open_clip":
+        TasksManager.standardize_model_attributes(model)
 
     if hasattr(model.config, "export_model_type"):
         model_type = model.config.export_model_type.replace("_", "-")
@@ -597,6 +589,12 @@
             kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
         )
 
+    if library_name == "open_clip":
+        custom_architecture = True
+        custom_export_configs, fn_get_submodels = _get_open_clip_submodels_fn_and_export_configs(
+            model, library_name, task, preprocessors, custom_export_configs, fn_get_submodels
+        )
+
     logging.disable(logging.INFO)
     export_config, models_and_export_configs = _get_submodels_and_export_configs(
         model=model,
@@ -614,7 +612,16 @@
     )
     logging.disable(logging.NOTSET)
 
-    if library_name != "diffusers":
+    if library_name == "open_clip":
+        if hasattr(model.config, "save_pretrained"):
+            model.config.save_pretrained(output)
+
+        for preprocess in preprocessors:
+            if hasattr(preprocess, "save_pretrained"):
+                preprocess.save_pretrained(output)
+
+        files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()]
+    elif library_name != "diffusers":
         # Saving the model config and preprocessor as this is needed sometimes.
         model.config.save_pretrained(output)
         generation_config = getattr(model, "generation_config", None)
@@ -744,6 +751,8 @@ def _add_version_info_to_model(model: Model, library_name: Optional[str] = None)
         model.set_rt_info(_optimum_version, ["optimum", "diffusers_version"])
     elif library_name == "timm":
         model.set_rt_info(_timm_version, ["optimum", "timm_version"])
+    elif library_name == "open_clip":
+        model.set_rt_info(_open_clip_version, ["optimum", "open_clip_version"])
     rt_info = model.get_rt_info()
     if "nncf" in rt_info:
         model.set_rt_info(_nncf_version, ["optimum", "nncf_version"])
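
The convert.py changes also shape the exported artifacts: each open_clip submodel is saved as openvino_<submodel_name>.xml, and the IR's runtime info is stamped with the installed open_clip version. A sketch of inspecting that metadata with the OpenVINO Python API; the output directory and submodel file name are assumptions (the actual keys come from _get_open_clip_submodels_fn_and_export_configs, which is not shown in this diff):

# Sketch: paths and the submodel name are assumed, not taken from this commit.
import openvino as ov

core = ov.Core()
model = core.read_model("ov_open_clip/openvino_model_text.xml")  # assumed file name
rt_info = model.get_rt_info()  # runtime metadata written at export time
print(rt_info)  # expected to include an "optimum" section with "open_clip_version"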
