66
77from collections import defaultdict
88from enum import Enum
9- from typing import Dict , List , Optional , Tuple , Callable , Any
9+ from typing import Any , Callable , Dict , List , Optional , Tuple
1010
1111import nncf
1212import nncf .common .quantization as quantization
@@ -351,30 +351,45 @@ def transform_for_annotation(
351351
352352def quantize_model (
353353 captured_model : torch .fx .GraphModule ,
354- quantizer : Quantizer ,
355354 calibration_dataset : torch .utils .data .DataLoader ,
356- subset_size : int ,
355+ * ,
356+ mode : QuantizationMode = QuantizationMode .INT8_SYM ,
357+ subset_size : int = 300 ,
357358 fast_bias_correction : Optional [bool ] = True ,
358359 smooth_quant : bool = False ,
359- transform_fn : Optional [Callable [[Any ], Any ]]= None ,
360+ transform_fn : Optional [Callable [[Any ], Any ]] = None ,
361+ extra_quantizer_options : Optional [Dict [str , Any ]] = None ,
360362 ** kwargs ,
361363) -> torch .fx .GraphModule :
362364 """
363365 Quantizes a model using NNCF quantize_pt2e API.
364366
365367 :param captured_model: The model to be quantized, represented as a torch.fx.GraphModule.
366- :param quantizer: Torch ao quantizer to annotate nodes in the graph with quantization setups
367368 :param calibration_dataset: A DataLoader containing calibration data for quantization.
369+ :param mode: Defines special quantization modes.
370+ - INT8_SYM: INT8 symmetric quantization for both activations and weights.
371+ - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights.
372+ - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models.
373+ Default value is INT8_SYM.
368374 :param subset_size: Size of a subset to calculate activations
369375 statistics used for quantization.
370376 :param fast_bias_correction: Setting this option to `False` enables a different
371377 bias correction method which is more accurate, in general, and takes
372378 more time but requires less memory. None disables the bias correction algorithm.
373379 :param smooth_quant: Setting this option to `True` enables the SmoothQuant algorithm.
380+ :param extra_quantizer_options: A dictionary containing additional configuration options
381+ for the OpenVINOQuantizer.
374382 :param kwargs: The keyword arguments for the nncf quantize_pt2e function.
375383 :return: The quantized model as a torch.fx.GraphModule.
376384 """
377- quantizer = OpenVINOQuantizer ()
385+ extra_quantizer_options = extra_quantizer_options or {}
386+ if "mode" in extra_quantizer_options :
387+ print (
388+ f'Ignoring "mode" from the quantizer_config. Using parameter mode = { mode } '
389+ )
390+ del extra_quantizer_options ["mode" ]
391+
392+ quantizer = OpenVINOQuantizer (mode = mode , ** extra_quantizer_options )
378393
379394 print ("PTQ: Quantize the model" )
380395
@@ -388,6 +403,6 @@ def quantize_model(
388403 calibration_dataset = nncf .Dataset (calibration_dataset , transform_fn ),
389404 fast_bias_correction = fast_bias_correction ,
390405 smooth_quant = smooth_quant ,
391- ** kwargs
406+ ** kwargs ,
392407 )
393408 return quantized_model
0 commit comments