@@ -640,9 +640,9 @@ def _prepare_decoder_calibration_data(
640640 self .model .request = InferRequestWrapper (self .model .request , collected_inputs )
641641 try :
642642 for data in tqdm (dataloader , desc = "Collecting calibration data" , total = num_samples ):
643- self .model .generate (** data , max_new_tokens = 1 )
644- if len (collected_inputs ) >= num_samples :
643+ if len (collected_inputs ) > num_samples :
645644 break
645+ self .model .generate (** data , max_new_tokens = 1 )
646646 finally :
647647 self .model .request = self .model .request .request
648648
@@ -695,6 +695,9 @@ def _prepare_visual_causal_lm_calibration_data(
695695 calibration_data = []
696696 num_samples = config .num_samples or 32
697697 for item in tqdm (dataset , desc = "Collecting calibration dataset" , total = num_samples ):
698+ if len (calibration_data ) > num_samples :
699+ break
700+
698701 instruction = item [dataset_metadata ["inputs" ]["instruction" ]]
699702 image_url = item [dataset_metadata ["inputs" ]["image_url" ]]
700703 image = Image .open (requests .get (image_url , stream = True ).raw ).convert ("RGB" )
@@ -725,9 +728,6 @@ def _prepare_visual_causal_lm_calibration_data(
725728
726729 calibration_data .append (language_model_inputs )
727730
728- if len (calibration_data ) >= num_samples :
729- break
730-
731731 return OVCalibrationDataset ({"lm_model" : nncf .Dataset (calibration_data )})
732732
733733 def _prepare_speech_to_text_calibration_data (
0 commit comments