@@ -114,10 +114,23 @@ class OVQuantizerTest(unittest.TestCase):
             (14, 22, 21) if is_transformers_version("<=", "4.42.4") else (14, 22, 25),
             (14, 21, 17) if is_transformers_version("<=", "4.42.4") else (14, 22, 18),
         ),
+        (
+            OVModelForCausalLM,
+            "llama",
+            OVQuantizationConfig(
+                dataset="wikitext2",
+                num_samples=1,
+                weight_only=False,
+                weight_format="f8e4m3",
+                activation_format="f8e4m3",
+            ),
+            (13,),
+            (16,),
+        ),
     ]
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL)
-    def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
+    def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_nodes, expected_int8_nodes):
         model_id = MODEL_NAMES[model_name]
         task = model_cls.export_feature
         dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
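Note: the new parameterized entry above exercises full static FP8 quantization (weights and activations both in "f8e4m3") rather than weight-only compression. A minimal standalone sketch of the same configuration; the checkpoint id is an assumption, since the test resolves its own via MODEL_NAMES["llama"]:

    from optimum.intel import OVModelForCausalLM, OVQuantizationConfig

    # Mirrors the quantization_config from the new test entry above.
    config = OVQuantizationConfig(
        dataset="wikitext2",         # calibration data is pulled automatically
        num_samples=1,               # one calibration sample, as in the test
        weight_only=False,           # static quantization, not weight-only
        weight_format="f8e4m3",      # FP8 weights
        activation_format="f8e4m3",  # FP8 activations
    )
    model = OVModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-LlamaForCausalLM",  # assumed tiny checkpoint
        quantization_config=config,
    )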
@@ -149,9 +162,9 @@ def preprocess_function(examples, tokenizer):
                 ov_config=ov_config,
             )
             model = model_cls.from_pretrained(tmp_dir, file_name=file_name)
-            num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-            self.assertEqual(expected_fake_quantize, num_fake_quantize)
-            self.assertEqual(expected_int8, num_weight_nodes["int8"])
+            num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+            self.assertEqual(expected_fake_nodes, num_fake_nodes)
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
 
             tokens = tokenizer("This is a sample input", return_tensors="pt")
             outputs = model(**tokens)
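The rename from num_fake_quantize to num_fake_nodes throughout this file tracks the new FP8 case: FP8 graphs presumably carry fake-conversion rather than FakeQuantize operations, so a precision-neutral name fits better. The contract these assertions rely on, inferred from usage in this diff (the real helper lives in the test utilities):

    # Inferred sketch, not the actual implementation of get_num_quantized_nodes.
    num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
    # num_fake_nodes:   int, number of fake quantization/conversion ops in the graph
    # num_weight_nodes: dict mapping weight precision to node count,
    #                   e.g. {"int8": 32, "int4": 0, "f8e4m3": 16}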
@@ -162,7 +175,7 @@ def preprocess_function(examples, tokenizer):
             self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict())
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL)
-    def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
+    def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_nodes, expected_int8_nodes):
         model_id = MODEL_NAMES[model_name]
         task = model_cls.export_feature
         dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
@@ -190,9 +203,9 @@ def preprocess_function(examples, tokenizer):
 
             model = model_cls.from_pretrained(tmp_dir)
 
-            num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-            self.assertEqual(expected_fake_quantize, num_fake_quantize)
-            self.assertEqual(expected_int8, num_weight_nodes["int8"])
+            num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+            self.assertEqual(expected_fake_nodes, num_fake_nodes)
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
 
             tokens = tokenizer("This is a sample input", return_tensors="pt")
             outputs = model(**tokens)
@@ -204,9 +217,10 @@ def preprocess_function(examples, tokenizer):
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL_WITH_AUTO_DATASET)
     def test_ov_model_static_quantization_with_auto_dataset(
-        self, model_cls, model_name, quantization_config, expected_fake_quantize, expected_int8
+        self, model_cls, model_name, quantization_config, expected_fake_nodes, expected_low_precision_nodes
     ):
         model_id = MODEL_NAMES[model_name]
+        quant_mode = quantization_config.activation_format
 
         with TemporaryDirectory() as tmp_dir:
             ov_model = model_cls.from_pretrained(model_id, quantization_config=quantization_config)
@@ -217,17 +231,28 @@ def test_ov_model_static_quantization_with_auto_dataset(
 
                 if ov_model.decoder_with_past is not None:
                     models.append(ov_model.decoder_with_past.model)
-                for model, expected_fq, expected_i8 in zip(
+                for model, expected_fake_nodes, expected_lp_nodes in zip(
                     models,
-                    expected_fake_quantize,
-                    expected_int8,
+                    expected_fake_nodes,
+                    expected_low_precision_nodes,
                 ):
-                    num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-                    self.assertEqual(expected_fq, num_fake_quantize)
-                    self.assertEqual(expected_i8, num_weight_nodes["int8"])
+                    num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+                    self.assertEqual(expected_fake_nodes, num_fake_nodes)
+                    self.assertEqual(expected_lp_nodes, num_weight_nodes[quant_mode])
 
                 input_features = torch.randn((1, 128, 3000), dtype=torch.float32)
                 ov_model.generate(input_features)
+            elif model_cls == OVModelForCausalLM:
+                num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(ov_model.model)
+                self.assertEqual(expected_fake_nodes[0], num_fake_nodes)
+                self.assertEqual(expected_low_precision_nodes[0], num_weight_nodes[quant_mode])
+
+                tokenizer = AutoTokenizer.from_pretrained(model_id)
+                if tokenizer.pad_token is None:
+                    tokenizer.pad_token = tokenizer.eos_token
+                tokens = tokenizer("This is a sample input", return_tensors="pt")
+                outputs = ov_model(**tokens)
+                self.assertTrue("logits" in outputs)
             else:
                 raise Exception("Unexpected model class.")
 
@@ -608,7 +633,7 @@ def test_ovmodel_8bit_weight_compression(self, model_cls, model_name, expected_p
             self.assertEqual(OVWeightQuantizationConfig().to_dict(), loaded_config.quantization_config.to_dict())
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS)
-    def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_int8, expected_int4):
+    def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_int8_nodes, expected_int4_nodes):
         task = model_cls.export_feature
         model_id = MODEL_NAMES[model_name]
         with TemporaryDirectory() as tmp_dir:
@@ -623,8 +648,8 @@ def test_ovmodel_4bit_weight_compression(self, model_cls, model_name, expected_i
             model = model_cls.from_pretrained(tmp_dir)
 
             _, num_weight_nodes = get_num_quantized_nodes(model)
-            self.assertEqual(expected_int8, num_weight_nodes["int8"])
-            self.assertEqual(expected_int4, num_weight_nodes["int4"])
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
+            self.assertEqual(expected_int4_nodes, num_weight_nodes["int4"])
 
             tokens = tokenizer("This is a sample input", return_tensors="pt")
             outputs = model(**tokens)
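For context, the 4-bit path above is weight-only: both INT4 and INT8 counts are asserted because layers excluded from 4-bit compression fall back to 8-bit. A minimal sketch of requesting it directly (the checkpoint id is an assumption):

    from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

    # bits=4 selects weight-only INT4 compression; the remaining weights stay INT8.
    config = OVWeightQuantizationConfig(bits=4)
    model = OVModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-LlamaForCausalLM",  # assumed tiny checkpoint
        quantization_config=config,
    )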
@@ -699,17 +724,17 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust
             self.assertEqual(expected_ov_int8[i], num_weight_nodes["int8"])
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION)
-    def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8):
+    def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_fake_nodes, expected_int8_nodes):
         model_id = MODEL_NAMES[model_type]
         quantization_config = OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions", num_samples=2)
         with TemporaryDirectory() as tmp_dir:
             model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config)
 
-            num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+            num_fake, num_weight_nodes = get_num_quantized_nodes(
                 model.unet if model.unet is not None else model.transformer
             )
-            self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
-            self.assertEqual(expected_ov_int8, num_weight_nodes["int8"])
+            self.assertEqual(expected_fake_nodes, num_fake)
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
             self.assertEqual(0, num_weight_nodes["int4"])
 
             model.save_pretrained(tmp_dir)
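The hybrid test above relies on the behavior that passing a dataset to OVWeightQuantizationConfig switches a diffusion pipeline to hybrid quantization (compressed weights plus quantized activations for selected ops), which is why both the fake-node and INT8 counts are checked. A sketch of the same call path, assuming a tiny Stable Diffusion checkpoint:

    from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig

    # Supplying a dataset alongside bits=8 triggers the hybrid path for diffusers.
    config = OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions", num_samples=2)
    pipe = OVStableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-torch",  # assumed tiny checkpoint
        export=True,
        quantization_config=config,
    )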
@@ -721,16 +746,16 @@ def test_stable_diffusion_with_weight_compression(self):
 
         quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config))
 
-        num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+        num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(
             int8_pipe.unet if int8_pipe.unet is not None else int8_pipe.transformer
         )
-        self.assertEqual(0, num_fake_quantize)
+        self.assertEqual(0, num_fake_nodes)
         self.assertEqual(242, num_weight_nodes["int8"])
         self.assertEqual(0, num_weight_nodes["int4"])
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION[-1:])
     def test_ovmodel_hybrid_quantization_with_custom_dataset(
-        self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8
+        self, model_cls, model_type, expected_fake_nodes, expected_int8_nodes
     ):
         model_id = MODEL_NAMES[model_type]
         dataset = [
@@ -742,11 +767,11 @@ def test_ovmodel_hybrid_quantization_with_custom_dataset(
         self.assertEqual(quantization_config.quant_method, OVQuantizationMethod.HYBRID)
 
         quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config), calibration_dataset=dataset)
-        num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(
+        num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(
             model.unet if model.unet is not None else model.transformer
         )
-        self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
-        self.assertEqual(expected_ov_int8, num_weight_nodes["int8"])
+        self.assertEqual(expected_fake_nodes, num_fake_nodes)
+        self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
         self.assertEqual(0, num_weight_nodes["int4"])
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS)
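The custom-dataset variant above skips the named dataset and feeds raw prompts through calibration_dataset instead. A sketch of that flow (the prompt strings and checkpoint are placeholders; the test's actual dataset entries are elided in this diff):

    from optimum.intel import OVConfig, OVQuantizer, OVStableDiffusionPipeline, OVWeightQuantizationConfig
    from optimum.intel.openvino.configuration import OVQuantizationMethod

    model = OVStableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-torch", export=True  # assumed checkpoint
    )
    dataset = ["a photo of a cat", "a photo of a dog"]  # hypothetical prompts
    quantization_config = OVWeightQuantizationConfig(
        bits=8, num_samples=2, quant_method=OVQuantizationMethod.HYBRID
    )
    quantizer = OVQuantizer(model)
    quantizer.quantize(
        ov_config=OVConfig(quantization_config=quantization_config),
        calibration_dataset=dataset,  # raw prompts instead of a dataset name
    )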
@@ -1050,7 +1075,7 @@ class OVTrainerTest(unittest.TestCase):
     @unittest.skipIf(
         is_transformers_version(">=", "4.46"), reason="OVTrainer is not compatible with transformers>=v4.46"
     )
-    def test_aware_training_quantization(self, model_name, expected_fake_quantize, expected_int8):
+    def test_aware_training_quantization(self, model_name, expected_fake_nodes, expected_int8_nodes):
         model_id = MODEL_NAMES[model_name]
         model = AutoModelForSequenceClassification.from_pretrained(model_id, attn_implementation="eager")
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -1084,9 +1109,9 @@ def compute_metrics(p):
             trainer.save_model()
 
             model = OVModelForSequenceClassification.from_pretrained(tmp_dir)
-            num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model)
-            self.assertEqual(expected_fake_quantize, num_fake_quantize)
-            self.assertEqual(expected_int8, num_weight_nodes["int8"])
+            num_fake_nodes, num_weight_nodes = get_num_quantized_nodes(model)
+            self.assertEqual(expected_fake_nodes, num_fake_nodes)
+            self.assertEqual(expected_int8_nodes, num_weight_nodes["int8"])
 
             tokens = tokenizer("This is a sample input", return_tensors="pt")
             outputs = model(**tokens)
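For orientation, the quantization-aware-training test renamed here drives a standard Trainer loop through OVTrainer. A rough sketch of the wiring, with model, tokenizer, datasets, and training_args assumed to exist as in the test fixture (treat the argument list as an assumption; as the skip marker above notes, OVTrainer is not compatible with transformers>=4.46):

    from optimum.intel import OVConfig, OVModelForSequenceClassification, OVTrainer

    # Assumed wiring mirroring transformers.Trainer; fixture objects
    # (model, tokenizer, train_dataset, eval_dataset, training_args) not shown.
    trainer = OVTrainer(
        model=model,
        ov_config=OVConfig(),  # assumption: default config enables INT8 QAT here
        task="text-classification",
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    trainer.save_model()
    ov_model = OVModelForSequenceClassification.from_pretrained(training_args.output_dir)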