@@ -42,6 +42,7 @@
     CLIPModelPatcher,
     CohereModelPatcher,
     FluxTransformerModelPatcher,
+    MetaCLIP2Patcher,
     MgpstrModelPatcher,
     MoonshineModelPatcher,
     MusicgenModelPatcher,
@@ -1247,6 +1248,85 @@ def outputs(self) -> dict[str, dict[int, str]]:
         return common_outputs


+@register_tasks_manager_onnx(
+    "metaclip_2",
+    *["feature-extraction", "zero-shot-image-classification", "image-classification"],
+    library_name="transformers",
+)
+class MetaCLIP2OnnxConfig(TextAndVisionOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
+    MIN_TRANSFORMERS_VERSION = version.parse("4.56.2")
+    VARIANTS = {  # noqa: RUF012
| 1260 | + "monolith": "All the MetaClip2 model components are exported as a single model.onnx.", |
| 1261 | + "split": "The vision model is exported as a separate vision_model.onnx, and the text_model is exported as text_model.onnx", |
+    }
+    DEFAULT_VARIANT = "monolith"
+    _MODEL_PATCHER = MetaCLIP2Patcher
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        variant: str = "monolith",
+        vision_model: bool | None = None,
+        preprocessors: list[Any] | None = None,
+    ):
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            preprocessors=preprocessors,
+        )
+        self.variant = variant
+        self.vision_model = vision_model
+
+    @property
+    def inputs(self) -> dict[str, dict[int, str]]:
+        if self.variant == "monolith":
+            inputs = {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
+            if self.task in ["feature-extraction", "zero-shot-image-classification"]:
+                inputs.update(
+                    {
+                        "input_ids": {0: "text_batch_size", 1: "sequence_length"},
+                        "attention_mask": {0: "text_batch_size", 1: "sequence_length"},
+                    }
+                )
+        else:
+            if self.vision_model:
+                inputs = {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
+            else:
+                inputs = {
+                    "input_ids": {0: "text_batch_size", 1: "sequence_length"},
+                    "attention_mask": {0: "text_batch_size", 1: "sequence_length"},
+                }
+        return inputs
+
+    @property
+    def outputs(self) -> dict[str, dict[int, str]]:
+        if self.variant == "split":
+            if self.vision_model:
+                return {
+                    "image_embeds": {0: "batch_size"},
+                }
+            else:
+                return {
+                    "text_embeds": {0: "batch_size"},
+                }
+        else:
+            if self.task in ["feature-extraction", "zero-shot-image-classification"]:
+                return {
+                    "logits_per_image": {0: "image_batch_size", 1: "text_batch_size"},
+                    "logits_per_text": {0: "text_batch_size", 1: "image_batch_size"},
+                    "text_embeds": {0: "text_batch_size"},
+                    "image_embeds": {0: "image_batch_size"},
+                }
+            else:
+                return super().outputs
+
+
 class SiglipNormalizedConfig(CLIPNormalizedConfig):
     pass

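For context, here is a minimal usage sketch (not part of this PR) showing how the new config's dynamic axes could be inspected once the class is registered. The checkpoint id below is a hypothetical placeholder; substitute a real MetaCLIP 2 checkpoint.

# Hedged sketch, assuming a MetaCLIP 2 checkpoint is available on the Hub;
# "facebook/metaclip-2-worldwide-huge" is a placeholder id, not a confirmed name.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("facebook/metaclip-2-worldwide-huge")

# "split" variant, vision side only: expects pixel_values, emits image_embeds.
onnx_config = MetaCLIP2OnnxConfig(
    config,
    task="zero-shot-image-classification",
    variant="split",
    vision_model=True,
)
print(onnx_config.inputs)   # {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
print(onnx_config.outputs)  # {"image_embeds": {0: "batch_size"}}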
|
|