
Commit 264205b

Add openvino Zero-shot-Image-Classification support (#1273)
* support Zero-shot-Image-Classification
* add tests
* update supported models
* skip siglip for old transformers
* apply review comments
1 parent c909268 commit 264205b
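
The commit adds an OVModelForZeroShotImageClassification class that mirrors the transformers AutoModelForZeroShotImageClassification API on top of OpenVINO. A minimal usage sketch of the new class (the checkpoint id and label prompts are illustrative, not part of the commit; the image URL is the one used in the tests):

import requests
from PIL import Image
from transformers import AutoProcessor
from optimum.intel import OVModelForZeroShotImageClassification

model_id = "openai/clip-vit-base-patch32"  # illustrative checkpoint, not taken from the commit
model = OVModelForZeroShotImageClassification.from_pretrained(model_id, export=True)  # convert to OpenVINO IR on the fly
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
).convert("RGB")
labels = ["a photo of a cat", "a photo of a dog"]
inputs = processor(images=image, text=labels, return_tensors="pt", padding=True)

outputs = model(**inputs)  # CLIPOutput with torch tensors, since the inputs are torch tensors
probs = outputs.logits_per_image.softmax(dim=-1)  # (num_images, num_labels) -> per-label probabilities
print(dict(zip(labels, probs[0].tolist())))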

10 files changed: 122 additions and 1 deletion

docs/source/openvino/models.mdx

Lines changed: 1 addition & 0 deletions
@@ -119,6 +119,7 @@ Here is the list of the supported architectures :
 - SEW
 - SEW-D
 - Segformer
+- SigLIP
 - SmolVLM(SmolVLM2)
 - SpeechT5 (text-to-speech)
 - SqueezeBert

optimum/intel/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -191,6 +191,7 @@
             "OVModelOpenCLIPVisual",
             "OVModelOpenCLIPText",
             "OVModelOpenCLIPForZeroShotImageClassification",
+            "OVModelForZeroShotImageClassification",
             "OVSamModel",
         ]
     )
@@ -356,6 +357,7 @@
         OVModelForTokenClassification,
         OVModelForVision2Seq,
         OVModelForVisualCausalLM,
+        OVModelForZeroShotImageClassification,
         OVModelOpenCLIPForZeroShotImageClassification,
         OVModelOpenCLIPText,
         OVModelOpenCLIPVisual,

optimum/intel/openvino/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@
     OVModelForQuestionAnswering,
     OVModelForSequenceClassification,
     OVModelForTokenClassification,
+    OVModelForZeroShotImageClassification,
 )
 from .modeling_decoder import OVModelForCausalLM
 from .modeling_open_clip import (

optimum/intel/openvino/modeling.py

Lines changed: 44 additions & 0 deletions
@@ -35,6 +35,7 @@
     AutoModelForQuestionAnswering,
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
+    AutoModelForZeroShotImageClassification,
     PretrainedConfig,
 )
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
@@ -49,6 +50,7 @@
     TokenClassifierOutput,
     XVectorOutput,
 )
+from transformers.models.clip.modeling_clip import CLIPOutput

 from ..utils.import_utils import is_timm_available, is_timm_version
 from .modeling_base import OVBaseModel
@@ -952,3 +954,45 @@ def forward(self, **kwargs):
             model_outputs[key_name] = torch.from_numpy(value).to(self.device) if not np_inputs else value

         return ModelOutput(**model_outputs)
+
+
+class OVModelForZeroShotImageClassification(OVModel):
+    auto_model_class = AutoModelForZeroShotImageClassification
+    export_feature = "zero-shot-image-classification"
+
+    def forward(self, input_ids, pixel_values, attention_mask: Optional[torch.Tensor] = None, **kwargs):
+        self.compile()
+
+        np_inputs = isinstance(input_ids, np.ndarray)
+        if not np_inputs:
+            input_ids = input_ids.cpu().numpy()
+            pixel_values = pixel_values.cpu().numpy()
+            attention_mask = attention_mask.cpu().numpy() if attention_mask is not None else attention_mask
+        inputs = {"input_ids": input_ids, "pixel_values": pixel_values}
+        # Add the attention_mask when needed
+        if "attention_mask" in self.input_names:
+            inputs["attention_mask"] = attention_mask if attention_mask is not None else np.ones_like(input_ids)
+        outputs = self._inference(inputs)
+        logits_per_image = (
+            torch.from_numpy(outputs["logits_per_image"]).to(self.device)
+            if not np_inputs
+            else outputs["logits_per_image"]
+        )
+        logits_per_text = (
+            torch.from_numpy(outputs["logits_per_text"]).to(self.device)
+            if not np_inputs
+            else outputs["logits_per_text"]
+        )
+        text_embeds = (
+            torch.from_numpy(outputs["text_embeds"]).to(self.device) if not np_inputs else outputs["text_embeds"]
+        )
+        image_embeds = (
+            torch.from_numpy(outputs["image_embeds"]).to(self.device) if not np_inputs else outputs["image_embeds"]
+        )
+
+        return CLIPOutput(
+            logits_per_image=logits_per_image,
+            logits_per_text=logits_per_text,
+            text_embeds=text_embeds,
+            image_embeds=image_embeds,
+        )
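
The forward above follows the same convention as the other OVModel heads: torch tensor inputs yield torch tensor outputs on the model device, while numpy inputs are fed straight to the compiled model and the OpenVINO results are returned as numpy arrays, in both cases wrapped in transformers' CLIPOutput. A minimal sketch of the numpy path, continuing the usage example from the commit description above (the assertion reflects the np_inputs branch in the new forward):

import numpy as np

np_inputs = {
    "input_ids": np.array(inputs["input_ids"]),        # token ids of the label prompts
    "pixel_values": np.array(inputs["pixel_values"]),  # preprocessed image batch
    "attention_mask": np.array(inputs["attention_mask"]),
}
np_outputs = model(**np_inputs)
assert isinstance(np_outputs.logits_per_image, np.ndarray)  # no torch conversion on the numpy path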

optimum/intel/openvino/utils.py

Lines changed: 1 addition & 0 deletions
@@ -129,6 +129,7 @@
     "question-answering": "OVModelForQuestionAnswering",
     "image-classification": "OVModelForImageClassification",
     "image-text-to-text": "OVModelForVisualCausalLM",
+    "zero-shot-image-classification": "OVModelForZeroShotImageClassification",
     "audio-classification": "OVModelForAudioClassification",
     "stable-diffusion": "OVStableDiffusionPipeline",
     "stable-diffusion-xl": "OVStableDiffusionXLPipeline",

optimum/intel/utils/dummy_openvino_objects.py

Lines changed: 11 additions & 0 deletions
@@ -59,6 +59,17 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["openvino"])


+class OVModelForZeroShotImageClassification(metaclass=DummyObject):
+    _backends = ["openvino"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["openvino"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["openvino"])
+
+
 class OVModelForAudioFrameClassification(metaclass=DummyObject):
     _backends = ["openvino"]

tests/openvino/test_export.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@
     OVModelForTextToSpeechSeq2Seq,
     OVModelForTokenClassification,
     OVModelForVisualCausalLM,
+    OVModelForZeroShotImageClassification,
     OVSamModel,
     OVStableDiffusion3Pipeline,
     OVStableDiffusionPipeline,
@@ -78,6 +79,7 @@ class ExportModelTest(unittest.TestCase):
         "llava": OVModelForVisualCausalLM,
         "sam": OVSamModel,
         "speecht5": OVModelForTextToSpeechSeq2Seq,
+        "clip": OVModelForZeroShotImageClassification,
     }

     EXPECTED_DIFFUSERS_SCALE_FACTORS = {

tests/openvino/test_exporters_cli.py

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,7 @@
     OVModelForTextToSpeechSeq2Seq,
     OVModelForTokenClassification,
     OVModelForVisualCausalLM,
+    OVModelForZeroShotImageClassification,
     OVModelOpenCLIPForZeroShotImageClassification,
     OVModelOpenCLIPText,
     OVModelOpenCLIPVisual,
@@ -87,6 +88,7 @@ class OVCLIExportTestCase(unittest.TestCase):
         ("image-to-image", "stable-diffusion-xl-refiner"),
         ("feature-extraction", "sam"),
         ("text-to-audio", "speecht5"),
+        ("zero-shot-image-classification", "clip"),
     ]

     if is_transformers_version(">=", "4.45"):
@@ -119,6 +121,7 @@ class OVCLIExportTestCase(unittest.TestCase):
         "ltx-video": 2 if is_tokenizers_version("<", "0.20.0") or is_openvino_version(">=", "2024.5") else 0,
         "sam": 0,  # no tokenizer
         "speecht5": 2,
+        "clip": 2 if is_tokenizers_version("<", "0.20.0") or is_openvino_version(">=", "2024.5") else 0,
     }

     TOKENIZER_CHAT_TEMPLATE_TESTS_MODELS = {

tests/openvino/test_modeling.py

Lines changed: 54 additions & 1 deletion
@@ -53,6 +53,7 @@
     AutoModelForSpeechSeq2Seq,
     AutoModelForTokenClassification,
     AutoModelForVision2Seq,
+    AutoModelForZeroShotImageClassification,
     AutoProcessor,
     AutoTokenizer,
     GenerationConfig,
@@ -95,6 +96,7 @@
     OVModelForTokenClassification,
     OVModelForVision2Seq,
     OVModelForVisualCausalLM,
+    OVModelForZeroShotImageClassification,
     OVModelOpenCLIPForZeroShotImageClassification,
     OVSamModel,
     OVSentenceTransformer,
@@ -2817,7 +2819,7 @@ def test_pipeline(self, model_arch: str):
         ov_model.reshape(1, -1)
         ov_model.compile()

-        # Speech recogition generation
+        # Image caption generation
         pipe = pipeline(
             "image-to-text",
             model=ov_model,
@@ -3295,5 +3297,56 @@ def test_compare_to_transformers(self, model_arch):
         del vocoder
         del model
         del processor
+        gc.collect()
+

+class OVModelForZeroShotImageClassificationIntegrationTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES = ["clip"]
+    if is_transformers_version(">=", "4.45"):
+        SUPPORTED_ARCHITECTURES.append("siglip")
+    TASK = "zero-shot-image-classification"
+    IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_compare_to_transformers(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        ov_model = OVModelForZeroShotImageClassification.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
+        processor = get_preprocessor(model_id)
+
+        self.assertIsInstance(ov_model.config, PretrainedConfig)
+
+        IMAGE = Image.open(
+            requests.get(
+                self.IMAGE_URL,
+                stream=True,
+            ).raw
+        ).convert("RGB")
+        labels = ["a photo of a cat", "a photo of a dog"]
+        inputs = processor(images=IMAGE, text=labels, return_tensors="pt")
+
+        transformers_model = AutoModelForZeroShotImageClassification.from_pretrained(model_id)
+
+        # test end-to-end inference
+        ov_outputs = ov_model(**inputs)
+
+        self.assertTrue("logits_per_image" in ov_outputs)
+        self.assertIsInstance(ov_outputs.logits_per_image, torch.Tensor)
+        self.assertTrue("logits_per_text" in ov_outputs)
+        self.assertIsInstance(ov_outputs.logits_per_text, torch.Tensor)
+        self.assertTrue("text_embeds" in ov_outputs)
+        self.assertIsInstance(ov_outputs.text_embeds, torch.Tensor)
+        self.assertTrue("image_embeds" in ov_outputs)
+        self.assertIsInstance(ov_outputs.image_embeds, torch.Tensor)
+
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**inputs)
+        # Compare tensor outputs
+        self.assertTrue(torch.allclose(ov_outputs.logits_per_image, transformers_outputs.logits_per_image, atol=1e-4))
+        self.assertTrue(torch.allclose(ov_outputs.logits_per_text, transformers_outputs.logits_per_text, atol=1e-4))
+        self.assertTrue(torch.allclose(ov_outputs.text_embeds, transformers_outputs.text_embeds, atol=1e-4))
+        self.assertTrue(torch.allclose(ov_outputs.image_embeds, transformers_outputs.image_embeds, atol=1e-4))
+
+        del transformers_model
+        del ov_model
         gc.collect()

tests/openvino/utils_tests.py

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,7 @@
     "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel",
     "bloom": "hf-internal-testing/tiny-random-BloomModel",
     "camembert": "hf-internal-testing/tiny-random-camembert",
+    "clip": "hf-tiny-model-private/tiny-random-CLIPModel",
     "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification",
     "cohere": "hf-internal-testing/tiny-random-CohereForCausalLM",
     "chatglm": "katuni4ka/tiny-random-chatglm2",
@@ -154,6 +155,7 @@
     "stable-diffusion-3": "yujiepan/stable-diffusion-3-tiny-random",
     "stablelm": "hf-internal-testing/tiny-random-StableLmForCausalLM",
     "starcoder2": "hf-internal-testing/tiny-random-Starcoder2ForCausalLM",
+    "siglip": "katuni4ka/tiny-random-SiglipModel",
     "latent-consistency": "echarlaix/tiny-random-latent-consistency",
     "sew": "hf-internal-testing/tiny-random-SEWModel",
     "sew_d": "asapp/sew-d-tiny-100k-ft-ls100h",
@@ -223,6 +225,7 @@
     "ltx-video": (34, 28, 28, 64),
     "sam": (102, 100),
     "speecht5": (28, 52, 10, 80),
+    "clip": (130,),
 }

 TEST_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
