|
53 | 53 | AutoModelForSpeechSeq2Seq, |
54 | 54 | AutoModelForTokenClassification, |
55 | 55 | AutoModelForVision2Seq, |
| 56 | + AutoModelForZeroShotImageClassification, |
56 | 57 | AutoProcessor, |
57 | 58 | AutoTokenizer, |
58 | 59 | GenerationConfig, |
|
95 | 96 | OVModelForTokenClassification, |
96 | 97 | OVModelForVision2Seq, |
97 | 98 | OVModelForVisualCausalLM, |
| 99 | + OVModelForZeroShotImageClassification, |
98 | 100 | OVModelOpenCLIPForZeroShotImageClassification, |
99 | 101 | OVSamModel, |
100 | 102 | OVSentenceTransformer, |
@@ -2817,7 +2819,7 @@ def test_pipeline(self, model_arch: str): |
2817 | 2819 | ov_model.reshape(1, -1) |
2818 | 2820 | ov_model.compile() |
2819 | 2821 |
2820 | | - # Speech recogition generation |
| 2822 | + # Image caption generation |
2821 | 2823 | pipe = pipeline( |
2822 | 2824 | "image-to-text", |
2823 | 2825 | model=ov_model, |
@@ -3295,5 +3297,56 @@ def test_compare_to_transformers(self, model_arch): |
3295 | 3297 | del vocoder |
3296 | 3298 | del model |
3297 | 3299 | del processor |
| 3300 | + gc.collect() |
| 3301 | + |
3298 | 3302 |
| 3303 | +class OVModelForZeroShotImageClassificationIntegrationTest(unittest.TestCase): |
| 3304 | + SUPPORTED_ARCHITECTURES = ["clip"] |
| 3305 | + if is_transformers_version(">=", "4.45"): |
| 3306 | + SUPPORTED_ARCHITECTURES.append("siglip") |
| 3307 | + TASK = "zero-shot-image-classification" |
| 3308 | + IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg" |
| 3309 | + |
| 3310 | + @parameterized.expand(SUPPORTED_ARCHITECTURES) |
| 3311 | + def test_compare_to_transformers(self, model_arch): |
| 3312 | + model_id = MODEL_NAMES[model_arch] |
| 3313 | + set_seed(SEED) |
| 3314 | + ov_model = OVModelForZeroShotImageClassification.from_pretrained(model_id, export=True, ov_config=F32_CONFIG) |
| 3315 | + processor = get_preprocessor(model_id) |
| 3316 | + |
| 3317 | + self.assertIsInstance(ov_model.config, PretrainedConfig) |
| 3318 | + |
| 3319 | + IMAGE = Image.open( |
| 3320 | + requests.get( |
| 3321 | + self.IMAGE_URL, |
| 3322 | + stream=True, |
| 3323 | + ).raw |
| 3324 | + ).convert("RGB") |
| 3325 | + labels = ["a photo of a cat", "a photo of a dog"] |
| 3326 | + inputs = processor(images=IMAGE, text=labels, return_tensors="pt") |
| 3327 | + |
| 3328 | + transformers_model = AutoModelForZeroShotImageClassification.from_pretrained(model_id) |
| 3329 | + |
| 3330 | + # test end-to-end inference |
| 3331 | + ov_outputs = ov_model(**inputs) |
| 3332 | + |
| 3333 | + self.assertTrue("logits_per_image" in ov_outputs) |
| 3334 | + self.assertIsInstance(ov_outputs.logits_per_image, torch.Tensor) |
| 3335 | + self.assertTrue("logits_per_text" in ov_outputs) |
| 3336 | + self.assertIsInstance(ov_outputs.logits_per_text, torch.Tensor) |
| 3337 | + self.assertTrue("text_embeds" in ov_outputs) |
| 3338 | + self.assertIsInstance(ov_outputs.text_embeds, torch.Tensor) |
| 3339 | + self.assertTrue("image_embeds" in ov_outputs) |
| 3340 | + self.assertIsInstance(ov_outputs.image_embeds, torch.Tensor) |
| 3341 | + |
| 3342 | + with torch.no_grad(): |
| 3343 | + transformers_outputs = transformers_model(**inputs) |
| 3344 | + # Compare tensor outputs |
| 3345 | + self.assertTrue(torch.allclose(ov_outputs.logits_per_image, transformers_outputs.logits_per_image, atol=1e-4)) |
| 3346 | + self.assertTrue(torch.allclose(ov_outputs.logits_per_text, transformers_outputs.logits_per_text, atol=1e-4)) |
| 3347 | + self.assertTrue(torch.allclose(ov_outputs.text_embeds, transformers_outputs.text_embeds, atol=1e-4)) |
| 3348 | + self.assertTrue(torch.allclose(ov_outputs.image_embeds, transformers_outputs.image_embeds, atol=1e-4)) |
| 3349 | + |
| 3350 | + del transformers_model |
| 3351 | + del ov_model |
3299 | 3352 | gc.collect() |
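
For reference, a minimal usage sketch of the `OVModelForZeroShotImageClassification` class that the new test above exercises. The checkpoint id and candidate labels below are illustrative assumptions, not values taken from this PR; the call pattern (export on load, processor-built inputs, `logits_per_image` output) mirrors the test.

```python
from PIL import Image
import requests
from transformers import AutoProcessor

from optimum.intel import OVModelForZeroShotImageClassification

# Illustrative checkpoint; any architecture covered by the new test (e.g. CLIP) follows the same pattern.
model_id = "openai/clip-vit-base-patch32"

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly, as in the test.
model = OVModelForZeroShotImageClassification.from_pretrained(model_id, export=True)
processor = AutoProcessor.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
labels = ["a photo of a cat", "a photo of a dog"]  # candidate labels (illustrative)

inputs = processor(images=image, text=labels, return_tensors="pt")
outputs = model(**inputs)

# logits_per_image has shape (num_images, num_labels); softmax gives per-label scores for CLIP-style models.
probs = outputs.logits_per_image.softmax(dim=-1)
print(dict(zip(labels, probs[0].tolist())))
```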