@@ -199,14 +199,28 @@ def _run_mct_qat(self, float_model, rep_dataset, abits, a_qmethod):
199199 quantized_model = pytorch_quantization_aware_training_finalize_experimental (qat_ready_model )
200200 return quantized_model
201201
def _run_exporter(self, quantized_model, rep_dataset, quantization_format, output_names=None):
    """Export *quantized_model* to ONNX at ``self.onnx_file`` and read it back.

    Args:
        quantized_model: MCT-quantized torch model to export.
        rep_dataset: Representative dataset generator used by the exporter.
        quantization_format: Target ``QuantizationFormat`` for the export.
        output_names: Optional explicit names for the ONNX graph outputs;
            None keeps the exporter's defaults.

    Returns:
        Whatever ``onnx_reader`` produces for the exported file (a parsed
        model/quantizer dict used by the assertion helpers).
    """
    export_kwargs = dict(
        save_model_path=self.onnx_file,
        repr_dataset=rep_dataset,
        serialization_format=PytorchExportSerializationFormat.ONNX,
        quantization_format=quantization_format,
        output_names=output_names,
    )
    pytorch_export_model(quantized_model, **export_kwargs)

    # The activation holder quantizer is needed by the reader to interpret
    # the quantized activations in the exported graph.
    act_quantizer = quantized_model.linear_activation_holder_quantizer.activation_holder_quantizer
    return onnx_reader(self.onnx_file, act_quantizer)
212+
def _assert_outputs_names(self, output_names):
    """Verify that the exported ONNX graph's output names match expectations.

    Args:
        output_names: Names that were requested from the exporter, or None
            if the exporter's defaults were used ('output' for a single
            output, 'output_0', 'output_1', ... for multiple outputs).

    Raises:
        AssertionError: If the exported names differ from the expected ones.
    """
    model = onnx.load(self.onnx_file)
    exported_output_names = [output.name for output in model.graph.output]

    if output_names is None:
        # Mirror the exporter's default naming scheme.
        if len(exported_output_names) == 1:
            output_names = ['output']
        else:
            output_names = [f"output_{i}" for i in range(len(exported_output_names))]

    # Compare as sorted lists: unlike the membership + length check, this also
    # fails when an expected name is duplicated, and reports both sides.
    assert sorted(output_names) == sorted(exported_output_names), (
        f"Expected ONNX output names {output_names}, "
        f"but exported graph has {exported_output_names}")
210224
def _assert_outputs_match(self, quantized_model, rep_dataset, quantization_format, tol=1e-8):
    """Hook for comparing quantized-model outputs against the exported model's.

    No-op in this class — presumably overridden by subclasses that perform
    the actual numeric comparison within *tol* (TODO confirm against the
    subclasses; only the empty base implementation is visible here).
    """
    pass
@@ -304,6 +318,17 @@ def test_mct_ptq_and_exporter_mctq(self, w_qmethod, abits, a_qmethod, tol):
304318 self ._assert_quant_params_match (quantized_model , onnx_model_dict , a_qmethod , w_qmethod )
305319 self ._assert_outputs_match (quantized_model , self .representative_dataset (1 ), QuantizationFormat .MCTQ , tol = tol )
306320
@pytest.mark.parametrize('w_qmethod', [mctq.QuantizationMethod.POWER_OF_TWO])
@pytest.mark.parametrize('a_qmethod', [mctq.QuantizationMethod.SYMMETRIC])
@pytest.mark.parametrize('abits', [8, 16])
@pytest.mark.parametrize('output_names', [None, ['x']])
def test_mct_ptq_exporter_mctq_output_names(self, w_qmethod, abits, a_qmethod, output_names):
    """PTQ + MCTQ export honors both custom and default ONNX output names."""
    quantized_model = self._run_mct(self.get_model(), self.representative_dataset(1), abits, a_qmethod, w_qmethod)
    # The exporter's return value is not needed here; this test only checks
    # the names of the exported graph outputs.
    self._run_exporter(quantized_model, self.representative_dataset(1), QuantizationFormat.MCTQ,
                       output_names=output_names)
    self._assert_outputs_names(output_names=output_names)
331+
307332 @pytest .mark .parametrize ('abits, tol' , ([8 , 1e-4 ], [16 , 1e-2 ]))
308333 def test_mct_ptq_and_exporter_fq (self , abits , tol ):
309334 quantized_model = self ._run_mct (self .get_model (), self .representative_dataset (1 ), abits , mctq .QuantizationMethod .POWER_OF_TWO )
@@ -363,6 +388,25 @@ def forward(self, x):
363388 self ._run_exporter (quantized_model , self .representative_dataset (1 ), QuantizationFormat .MCTQ )
364389 self ._assert_outputs_match (quantized_model , self .representative_dataset (1 ), QuantizationFormat .MCTQ )
365390
@pytest.mark.parametrize('abits', [8, 16])
@pytest.mark.parametrize('output_names', [None, ['x', 'y']])
def test_multi_output_names_mct_and_exporter_mctq(self, abits, output_names):
    """Exporting a two-output model preserves default / custom output names."""

    class TwoHeadModel(torch.nn.Module):
        # Two parallel linear heads -> two graph outputs.
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.linear = torch.nn.Linear(in_channels, out_channels)
            self.linear_y = torch.nn.Linear(in_channels, out_channels)

        def forward(self, x):
            return self.linear(x), self.linear_y(x)

    float_model = TwoHeadModel(self.in_channels, self.out_channels)
    quantized_model = self._run_mct(float_model,
                                    self.representative_dataset(1),
                                    abits,
                                    mctq.QuantizationMethod.POWER_OF_TWO)
    self._run_exporter(quantized_model,
                       self.representative_dataset(1),
                       QuantizationFormat.MCTQ,
                       output_names=output_names)
    self._assert_outputs_names(output_names=output_names)
409+
366410 @pytest .mark .parametrize ('abits' , [8 , 16 ])
367411 def test_multi_input_output_mct_and_exporter_mctq (self , abits ):
368412 class MultiInputOutputModel (torch .nn .Module ):
0 commit comments