@@ -18,6 +18,7 @@
 import os
 from collections import UserDict, deque
 from contextlib import contextmanager
+from io import BytesIO
 from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
@@ -73,6 +74,7 @@
     PREDEFINED_LANGUAGE_DATASETS,
     PREDEFINED_SD_DATASETS,
     PREDEFINED_SPEECH_TO_TEXT_DATASETS,
+    PREDEFINED_TEXT_IMAGE_ENCODER_DATASETS,
     PREDEFINED_VISUAL_LM_DATASETS,
 )
 
@@ -268,6 +270,7 @@ def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> OVCalibrationDataset:
             OVModelForFeatureExtraction,
             OVModelForMaskedLM,
             OVModelForVisualCausalLM,
+            OVModelForZeroShotImageClassification,
             OVSentenceTransformer,
         )
         from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper
@@ -280,7 +283,9 @@ def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> OVCalibrationDataset:
 
         if isinstance(self.model, OVModelForCausalLM):
             return self._prepare_causal_lm_calibration_data(config)
-        elif isinstance(self.model, (OVModelForVisualCausalLM, _OVModelForWhisper)):
+        elif isinstance(
+            self.model, (OVModelForVisualCausalLM, _OVModelForWhisper, OVModelForZeroShotImageClassification)
+        ):
             if config.processor is None:
                 raise ValueError(
                     "`processor` must be specified in order to run data-aware quantization. Please provide it as a"
@@ -307,6 +312,16 @@ def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> OVCalibrationDataset:
                     trust_remote_code=config.trust_remote_code,
                     streaming=dataset_metadata["streaming"],
                 )
+            elif isinstance(self.model, OVModelForZeroShotImageClassification):
+                dataset_metadata = PREDEFINED_TEXT_IMAGE_ENCODER_DATASETS[config.dataset]
+                return self.build_from_dataset_name(
+                    config,
+                    dataset_metadata["id"],
+                    num_samples=None,
+                    dataset_split=dataset_metadata["split"],
+                    trust_remote_code=config.trust_remote_code,
+                    streaming=dataset_metadata["streaming"],
+                )
             else:
                 raise Exception
         elif is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline):
@@ -330,13 +345,14 @@ def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> OVCalibrationDataset:
             return self.build_from_dataset(config, dataset)
         elif isinstance(self.model, (OVModelForFeatureExtraction, OVSentenceTransformer, OVModelForMaskedLM)):
             if isinstance(config.dataset, str):
+                dataset_metadata = PREDEFINED_LANGUAGE_DATASETS[config.dataset]
                 dataset = self.load_dataset(
-                    PREDEFINED_LANGUAGE_DATASETS[config.dataset]["path"],
+                    dataset_metadata["id"],
                     num_samples=None,
-                    dataset_config_name=PREDEFINED_LANGUAGE_DATASETS[config.dataset]["name"],
-                    dataset_split=PREDEFINED_LANGUAGE_DATASETS[config.dataset]["split"],
+                    dataset_config_name=dataset_metadata["name"],
+                    dataset_split=dataset_metadata["split"],
                     trust_remote_code=config.trust_remote_code,
-                    streaming=PREDEFINED_LANGUAGE_DATASETS[config.dataset]["streaming"],
+                    streaming=dataset_metadata["streaming"],
                 )
             elif isinstance(config.dataset, list) and all(isinstance(it, str) for it in config.dataset):
                 dataset = datasets.Dataset.from_list([{"text": it} for it in config.dataset])
@@ -345,6 +361,8 @@ def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> OVCalibrationDataset:
                     "Please provide dataset as one of the accepted dataset labels or as a list of strings."
                 )
             return self.build_from_dataset(config, dataset)
+        else:
+            raise RuntimeError("Unsupported model type for calibration dataset collection.")
 
     def build_from_dataset_name(
         self,
@@ -449,6 +467,7 @@ def build_from_dataset(
             OVModelForFeatureExtraction,
             OVModelForMaskedLM,
             OVModelForVisualCausalLM,
+            OVModelForZeroShotImageClassification,
             OVSentenceTransformer,
         )
         from optimum.intel.openvino.modeling_decoder import OVBaseDecoderModel
@@ -470,6 +489,7 @@ def build_from_dataset(
                 _OVModelForWhisper,
                 OVModelForFeatureExtraction,
                 OVModelForMaskedLM,
+                OVModelForZeroShotImageClassification,
                 OVSentenceTransformer,
             ),
         ) or (is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline)):
@@ -487,6 +507,8 @@ def build_from_dataset(
                 return self._prepare_diffusion_calibration_data(quantization_config, dataset)
             elif isinstance(self.model, (OVModelForFeatureExtraction, OVSentenceTransformer, OVModelForMaskedLM)):
                 return self._prepare_text_encoder_model_calibration_data(quantization_config, dataset)
+            elif isinstance(self.model, OVModelForZeroShotImageClassification):
+                return self._prepare_text_image_encoder_model_calibration_data(quantization_config, dataset)
             else:
                 raise RuntimeError("Unsupported model type for calibration dataset collection.")
         else:
@@ -878,6 +900,79 @@ def get_tokenizer():
 
         return OVCalibrationDataset({"model": nncf.Dataset(calibration_data)})
 
+    def _prepare_text_image_encoder_model_calibration_data(
+        self,
+        quantization_config: OVQuantizationConfigBase,
+        dataset: "Dataset",
+        seq_len: int = 128,
+    ) -> OVCalibrationDataset:
+        self.model.compile()
+
+        def get_processor():
+            processor = AutoProcessor.from_pretrained(
+                quantization_config.processor, trust_remote_code=quantization_config.trust_remote_code
+            )
+            return processor
+
+        max_position_embeddings = getattr(self.model.config, "max_position_embeddings", None)
+        if max_position_embeddings is not None and max_position_embeddings > 0:
+            seq_len = min(seq_len, max_position_embeddings)
+
+        num_samples = quantization_config.num_samples or 128
+        calibration_data = []
+        try:
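+            # Mock inference outputs: passing these to InferRequestWrapper lets input
+            # collection proceed without running actual inference on each sample.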
+            inference_result_mock = {
+                "logits_per_image": np.empty((1,), np.float32),
+                "logits_per_text": np.empty((1,), np.float32),
+                "text_embeds": np.empty((1,), np.float32),
+                "image_embeds": np.empty((1,), np.float32),
+            }
+
+            self.model.request = InferRequestWrapper(
+                self.model.request,
+                calibration_data,
+                inference_result_mock=inference_result_mock,
+            )
+
+            processor = None
+            pbar = tqdm(total=num_samples, desc="Collecting calibration data")
+            for item in dataset:
+                if "input_ids" in item:
+                    # Assume the dataset already contains preprocessed text
+                    inputs = self._wrap_sample_as_array(item, add_batch_dim=True)
+                else:
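+                    # Predefined text-image datasets reference images by URL; download
+                    # each image and skip samples that cannot be fetched.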
+                    dataset_metadata = PREDEFINED_TEXT_IMAGE_ENCODER_DATASETS[quantization_config.dataset]
+                    try:
+                        response = requests.get(item[dataset_metadata["image_column_name"]], timeout=5)
+                        response.raise_for_status()
+                        image = Image.open(BytesIO(response.content))
+                    except Exception:
+                        continue
+                    processor = processor or get_processor()
+                    inputs = processor(
+                        text=item[dataset_metadata["text_column_name"]],
+                        images=image.convert("RGB"),
+                        return_tensors="pt",
+                        padding=True,
+                    )
+                    if inputs["input_ids"].shape[1] > seq_len:
+                        inputs["input_ids"] = inputs["input_ids"][:, :seq_len]
+
+                self.model(**inputs)
+
+                pbar.update(min(num_samples, len(calibration_data)) - pbar.n)
+                if len(calibration_data) >= num_samples:
+                    break
+        finally:
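+            # Unwrap InferRequestWrapper to restore the model's original infer request.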
+            self.model.request = self.model.request.request
+
+        return OVCalibrationDataset({"model": nncf.Dataset(calibration_data)})
+
     @staticmethod
     def _wrap_sample_as_array(
         sample: Dict[str, Any], add_batch_dim: bool = False
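
A minimal usage sketch of the calibration path added above, assuming a CLIP checkpoint; the model ID, dataset label, and sample count are illustrative (the accepted labels are the keys of PREDEFINED_TEXT_IMAGE_ENCODER_DATASETS):

    from optimum.intel import OVModelForZeroShotImageClassification, OVQuantizationConfig

    # Data-aware 8-bit quantization of a text-image encoder; `processor` must be set,
    # since the new branch raises a ValueError without it.
    quantization_config = OVQuantizationConfig(
        bits=8,
        dataset="conceptual_captions",  # assumed predefined text-image dataset label
        processor="openai/clip-vit-base-patch16",  # preprocesses the text-image pairs
        num_samples=64,
    )
    model = OVModelForZeroShotImageClassification.from_pretrained(
        "openai/clip-vit-base-patch16", export=True, quantization_config=quantization_config
    )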