
Commit 0ad63df

Add Metaclip-2 export (#71)
This PR adds support for exporting MetaClip-2 models to ONNX and introduces an ONNX Runtime model class for the zero-shot-image-classification task (ORTModelForZeroShotImageClassification).

Co-authored-by: Ilyas Moutawwakil <[email protected]>
1 parent 1e4fe5b commit 0ad63df

File tree: 16 files changed, +502 −13 lines
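As context for the diffs below (not part of the commit itself), here is a minimal sketch of the export flow this PR enables, using the programmatic entry point exposed by `optimum.exporters.onnx`. The checkpoint id and output directory are illustrative, and the default "monolith" variant is assumed.

```python
# Minimal export sketch (illustrative checkpoint id and output path):
# converts a MetaCLIP 2 checkpoint to ONNX for zero-shot image classification
# using the default "monolith" variant added by this PR.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="facebook/metaclip-2-worldwide-huge-quickgelu",  # assumed Hub id, replace as needed
    output="metaclip_2_onnx",
    task="zero-shot-image-classification",
)
```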

docs/source/onnx/overview.mdx

Lines changed: 1 addition & 0 deletions

@@ -79,6 +79,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - MarkupLM
 - MaskFormer
 - MBart
+- MetaClip2
 - MGP-STR
 - Mistral
 - MobileBert

docs/source/onnxruntime/package_reference/modeling_ort.mdx

Lines changed: 5 additions & 0 deletions

@@ -68,6 +68,11 @@ The following ORT classes are available for the following computer vision tasks.
 [[autodoc]] onnxruntime.ORTModelForImageClassification
     - forward
 
+### ORTModelForZeroShotImageClassification
+
+[[autodoc]] onnxruntime.ORTModelForZeroShotImageClassification
+    - forward
+
 ### ORTModelForSemanticSegmentation
 
 [[autodoc]] onnxruntime.ORTModelForSemanticSegmentation
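To accompany the new documentation entry, a minimal inference sketch with the newly exposed class. Assumptions: the checkpoint id, image URL, and candidate labels are illustrative, `from_pretrained(..., export=True)` is assumed to convert the model on the fly as with the other ORTModel classes, and the output is assumed to expose `logits_per_image` as declared in the ONNX config added below.

```python
# Hedged usage sketch for the new ORTModelForZeroShotImageClassification class.
import requests
from PIL import Image
from transformers import AutoProcessor

from optimum.onnxruntime import ORTModelForZeroShotImageClassification

model_id = "facebook/metaclip-2-worldwide-huge-quickgelu"  # illustrative checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = ORTModelForZeroShotImageClassification.from_pretrained(model_id, export=True)  # assumed export path

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
)
outputs = model(**inputs)
# Probabilities of the image matching each candidate label.
probs = outputs.logits_per_image.softmax(dim=-1)
print(probs)
```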

optimum/exporters/onnx/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -29,7 +29,7 @@
         "validate_models_outputs",
         "onnx_export_from_model",
     ],
-    "utils": ["MODEL_TYPES_REQUIRING_POSITION_IDS"],
+    "utils": ["MODEL_TYPES_REQUIRING_POSITION_IDS", "get_metaclip_2_models_for_export"],
     "__main__": ["main_export"],
 }
 
@@ -44,7 +44,7 @@
         validate_model_outputs,
         validate_models_outputs,
     )
-    from optimum.exporters.onnx.utils import MODEL_TYPES_REQUIRING_POSITION_IDS
+    from optimum.exporters.onnx.utils import MODEL_TYPES_REQUIRING_POSITION_IDS, get_metaclip_2_models_for_export
 else:
     import sys

optimum/exporters/onnx/__main__.py

Lines changed: 1 addition & 0 deletions

@@ -449,6 +449,7 @@ def main():
         pad_token_id=args.pad_token_id,
         library_name=args.library_name,
         do_constant_folding=not args.no_constant_folding,
+        slim=args.slim,
         **input_shapes,
     )

optimum/exporters/onnx/convert.py

Lines changed: 1 addition & 1 deletion

@@ -211,7 +211,7 @@ def validate_model_outputs(
         io_process.join()
 
         if io_process.exception:
-            error, traceback = io_process.exception
+            error, _ = io_process.exception
             raise error
     else:
         _run_validation(

optimum/exporters/onnx/model_configs.py

Lines changed: 80 additions & 0 deletions

@@ -42,6 +42,7 @@
     CLIPModelPatcher,
     CohereModelPatcher,
     FluxTransformerModelPatcher,
+    MetaCLIP2Patcher,
     MgpstrModelPatcher,
     MoonshineModelPatcher,
     MusicgenModelPatcher,
@@ -1247,6 +1248,85 @@ def outputs(self) -> dict[str, dict[int, str]]:
         return common_outputs
 
 
+@register_tasks_manager_onnx(
+    "metaclip_2",
+    *["feature-extraction", "zero-shot-image-classification", "image-classification"],
+    library_name="transformers",
+)
+class MetaCLIP2OnnxConfig(TextAndVisionOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
+    MIN_TRANSFORMERS_VERSION = version.parse("4.56.2")
+    VARIANTS = {  # noqa: RUF012
+        "monolith": "All the MetaClip2 model components are exported as a single model.onnx.",
+        "split": "The vision model is exported as a separate vision_model.onnx, and the text_model is exported as text_model.onnx",
+    }
+    DEFAULT_VARIANT = "monolith"
+    _MODEL_PATCHER = MetaCLIP2Patcher
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        variant: str = "monolith",
+        vision_model: bool | None = None,
+        preprocessors: list[Any] | None = None,
+    ):
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            preprocessors=preprocessors,
+        )
+        self.variant = variant
+        self.vision_model = vision_model
+
+    @property
+    def inputs(self) -> dict[str, dict[int, str]]:
+        if self.variant == "monolith":
+            inputs = {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
+            if self.task in ["feature-extraction", "zero-shot-image-classification"]:
+                inputs.update(
+                    {
+                        "input_ids": {0: "text_batch_size", 1: "sequence_length"},
+                        "attention_mask": {0: "text_batch_size", 1: "sequence_length"},
+                    }
+                )
+        else:
+            if self.vision_model:
+                inputs = {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
+            else:
+                inputs = {
+                    "input_ids": {0: "text_batch_size", 1: "sequence_length"},
+                    "attention_mask": {0: "text_batch_size", 1: "sequence_length"},
+                }
+        return inputs
+
+    @property
+    def outputs(self) -> dict[str, dict[int, str]]:
+        if self.variant == "split":
+            if self.vision_model:
+                return {
+                    "image_embeds": {0: "batch_size"},
+                }
+            else:
+                return {
+                    "text_embeds": {0: "batch_size"},
+                }
+        else:
+            if self.task in ["feature-extraction", "zero-shot-image-classification"]:
+                return {
+                    "logits_per_image": {0: "image_batch_size", 1: "text_batch_size"},
+                    "logits_per_text": {0: "text_batch_size", 1: "image_batch_size"},
+                    "text_embeds": {0: "text_batch_size"},
+                    "image_embeds": {0: "image_batch_size"},
+                }
+            else:
+                return super().outputs
+
+
 class SiglipNormalizedConfig(CLIPNormalizedConfig):
     pass
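To make the variant mechanics above concrete, here is a small sketch (not part of the diff) that instantiates the new config and inspects the dynamic axes it declares. The checkpoint id is illustrative, and `AutoConfig` is assumed to resolve it to a `metaclip_2` configuration.

```python
# Sketch: how the variant/vision_model flags of MetaCLIP2OnnxConfig change the
# exported input/output signature (checkpoint id illustrative).
from transformers import AutoConfig

from optimum.exporters.onnx.model_configs import MetaCLIP2OnnxConfig

config = AutoConfig.from_pretrained("facebook/metaclip-2-worldwide-huge-quickgelu")

# Monolith: one model.onnx taking both pixel_values and the text inputs.
monolith = MetaCLIP2OnnxConfig(config, task="zero-shot-image-classification", variant="monolith")
print(monolith.inputs)   # pixel_values, input_ids, attention_mask
print(monolith.outputs)  # logits_per_image, logits_per_text, text_embeds, image_embeds

# Split: separate vision_model.onnx / text_model.onnx, each with its own signature.
vision = MetaCLIP2OnnxConfig(config, task="feature-extraction", variant="split", vision_model=True)
print(vision.inputs)     # pixel_values only
print(vision.outputs)    # image_embeds only
```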

optimum/exporters/onnx/model_patcher.py

Lines changed: 27 additions & 1 deletion

@@ -1196,6 +1196,32 @@ def patched_forward(
         self.patched_forward = patched_forward
 
 
+class MetaCLIP2Patcher(ModelPatcher):
+    def __init__(
+        self,
+        config: OnnxConfig,
+        model: PreTrainedModel,
+        model_kwargs: dict[str, Any] | None = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        def patched_forward(input_ids=None, pixel_values=None, attention_mask=None):
+            if config.variant == "monolith":
+                return self.orig_forward(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
+
+            if config.variant == "split":
+                if config.vision_model:
+                    image_embeds = model.get_image_features(pixel_values)
+                    return {"image_embeds": image_embeds}
+
+                text_embeds = model.get_text_features(input_ids, attention_mask)
+                return {
+                    "text_embeds": text_embeds,
+                }
+
+        self.patched_forward = patched_forward
+
+
 class CLIPModelPatcher(ModelPatcher):
     def __enter__(self):
         super().__enter__()
@@ -1379,7 +1405,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 def patched_cohere_rotary_forward(self, x, position_ids):
     # Get batch size and sequence length for manual expansion
-    batch_size, seq_len = position_ids.shape[:2]
+    batch_size, _ = position_ids.shape[:2]
 
     # Instead of using expand, manually repeat the tensor.
     # Problem with expand: it creates a view with shared memory rather than copying data,
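For intuition about what the patched forward above routes each split export through, here is a short sketch. Assumptions: the checkpoint id and image URL are illustrative, and a transformers version with MetaCLIP 2 support (≥ 4.56.2 per the config above) is installed.

```python
# Sketch of what the split-variant submodels compute: the vision export wraps
# get_image_features and the text export wraps get_text_features, mirroring the
# patched forward above. Checkpoint id and image URL are illustrative.
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model_id = "facebook/metaclip-2-worldwide-huge-quickgelu"  # illustrative
model = AutoModel.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)

with torch.no_grad():
    image_embeds = model.get_image_features(pixel_values=inputs["pixel_values"])
    text_embeds = model.get_text_features(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    )
print(image_embeds.shape, text_embeds.shape)
```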

optimum/exporters/onnx/utils.py

Lines changed: 46 additions & 0 deletions

@@ -20,6 +20,8 @@
 import torch
 from transformers.utils import is_torch_available
 
+from optimum.exporters.base import ExporterConfig
+from optimum.exporters.tasks import TasksManager
 from optimum.exporters.utils import _get_submodels_and_export_configs
 from optimum.utils.import_utils import is_transformers_version
 
@@ -131,6 +133,38 @@ def __setstate__(self, values):
         self.sess = ort.InferenceSession(self.model_path, sess_options=self.sess_options, providers=self.providers)
 
 
+def _get_submodels_for_export_metaclip_2(model, variant):
+    models_for_export = {}
+
+    if variant == "monolith":
+        models_for_export["model"] = model
+    else:
+        # We rather use the model patcher to patch their forward method.
+        models_for_export["vision_model"] = model
+        models_for_export["text_model"] = model
+
+    return models_for_export
+
+
+def get_metaclip_2_models_for_export(model: PreTrainedModel, config: ExporterConfig):
+    models_for_export = _get_submodels_for_export_metaclip_2(model, config.variant)
+
+    if config.variant == "monolith":
+        export_config = config.__class__(model.config, task=config.task, variant=config.variant)
+        models_for_export["model"] = (models_for_export["model"], export_config)
+    else:
+        vision_model_export_config = config.__class__(
+            model.config, task=config.task, variant=config.variant, vision_model=True
+        )
+        text_model_export_config = config.__class__(
+            model.config, task=config.task, variant=config.variant, vision_model=False
+        )
+        models_for_export["vision_model"] = (models_for_export["vision_model"], vision_model_export_config)
+        models_for_export["text_model"] = (models_for_export["text_model"], text_model_export_config)
+
+    return models_for_export
+
+
 def _get_submodels_and_onnx_configs(
     model: PreTrainedModel,
     task: str,
@@ -145,6 +179,18 @@ def _get_submodels_and_onnx_configs(
     preprocessors: list[Any] | None = None,
     model_kwargs: dict | None = None,
 ):
+    if library_name == "transformers" and model.config.model_type == "metaclip_2":
+        export_config_constructor = TasksManager.get_exporter_config_constructor(
+            model=model, exporter="onnx", task=task, library_name="transformers"
+        )
+        export_config = export_config_constructor(
+            model.config,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            preprocessors=preprocessors,
+        )
+        export_config.variant = _variant
+        return export_config, get_metaclip_2_models_for_export(model, export_config)
     return _get_submodels_and_export_configs(
         model,
         task,
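A small sketch (not part of the diff) of the submodel-to-config mapping the helper above produces for the split variant; the checkpoint id is illustrative, and the helper is imported through the re-export added to optimum/exporters/onnx/__init__.py in this PR.

```python
# Sketch: get_metaclip_2_models_for_export maps each export name to a
# (model, export config) pair. With variant="split" the same model is exported
# twice, once per modality. Checkpoint id illustrative.
from transformers import AutoModel

from optimum.exporters.onnx import get_metaclip_2_models_for_export
from optimum.exporters.onnx.model_configs import MetaCLIP2OnnxConfig

model = AutoModel.from_pretrained("facebook/metaclip-2-worldwide-huge-quickgelu")
export_config = MetaCLIP2OnnxConfig(model.config, task="feature-extraction", variant="split")

models_for_export = get_metaclip_2_models_for_export(model, export_config)
print(list(models_for_export))  # ['vision_model', 'text_model']
for name, (submodel, sub_config) in models_for_export.items():
    print(name, type(sub_config).__name__, sub_config.vision_model)
```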

optimum/onnxruntime/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -53,6 +53,7 @@
         "ORTModelForCTC",
         "ORTModelForFeatureExtraction",
         "ORTModelForImageClassification",
+        "ORTModelForZeroShotImageClassification",
         "ORTModelForMaskedLM",
         "ORTModelForMultipleChoice",
         "ORTModelForQuestionAnswering",
@@ -159,6 +160,7 @@
         ORTModelForSemanticSegmentation,
         ORTModelForSequenceClassification,
         ORTModelForTokenClassification,
+        ORTModelForZeroShotImageClassification,
     )
     from optimum.onnxruntime.modeling_seq2seq import (
         ORTModelForPix2Struct,

optimum/onnxruntime/modeling_decoder.py

Lines changed: 1 addition & 1 deletion

@@ -362,7 +362,7 @@ def forward(
         outputs_to_not_bind = set()
         if use_cache and self.use_io_binding:
             # Infers the shape of the output pkv
-            batch_size, seq_len = input_ids.shape
+            batch_size, _ = input_ids.shape
             if self.old_bloom_modeling:
                 num_key_value_heads_batch_size, embed_size_per_head = past_key_values[0].shape[:2]
                 k_shape = (num_key_value_heads_batch_size, embed_size_per_head, out_seq_len)
