@@ -6,7 +6,7 @@
 
 from collections import defaultdict
 from enum import Enum
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import nncf
 import nncf.common.quantization as quantization
@@ -345,31 +345,59 @@ def validate(self, model: torch.fx.GraphModule) -> None:
 def quantize_model(
     captured_model: torch.fx.GraphModule,
     calibration_dataset: torch.utils.data.DataLoader,
+    *,
+    mode: QuantizationMode = QuantizationMode.INT8_SYM,
+    subset_size: int = 300,
+    fast_bias_correction: Optional[bool] = True,
+    smooth_quant: bool = False,
+    transform_fn: Optional[Callable[[Any], Any]] = None,
+    extra_quantizer_options: Optional[Dict[str, Any]] = None,
+    **kwargs,
 ) -> torch.fx.GraphModule:
     """
-    Quantizes a model using either NNCF-based or PTQ-based quantization.
+    Quantizes a model using the NNCF quantize_pt2e API.
 
     :param captured_model: The model to be quantized, represented as a torch.fx.GraphModule.
     :param calibration_dataset: A DataLoader containing calibration data for quantization.
+    :param mode: Defines special quantization modes.
+        - INT8_SYM: INT8 symmetric quantization for both activations and weights.
+        - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights.
+        - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models.
+        Default value is INT8_SYM.
+    :param subset_size: Size of the subset of calibration samples used to compute
+        activation statistics for quantization.
+    :param fast_bias_correction: Setting this option to `False` enables an alternative
+        bias correction method that is generally more accurate and uses less memory,
+        but takes more time. `None` disables the bias correction algorithm.
+    :param smooth_quant: Setting this option to `True` enables the SmoothQuant algorithm.
+    :param transform_fn: An optional function applied to each item of the calibration
+        dataset to extract the model input.
+    :param extra_quantizer_options: A dictionary containing additional configuration options
+        for the OpenVINOQuantizer.
+    :param kwargs: Additional keyword arguments for the nncf quantize_pt2e function.
     :return: The quantized model as a torch.fx.GraphModule.
     """
-    quantizer = OpenVINOQuantizer()
+    extra_quantizer_options = extra_quantizer_options or {}
+    if "mode" in extra_quantizer_options:
+        print(
+            f'Ignoring "mode" from extra_quantizer_options. Using parameter mode = {mode}'
+        )
+        del extra_quantizer_options["mode"]
+
+    quantizer = OpenVINOQuantizer(mode=mode, **extra_quantizer_options)
 
     print("PTQ: Quantize the model")
-    default_subset_size = 300
-    batch_size = calibration_dataset.batch_size
-    subset_size = (default_subset_size // batch_size) + int(
-        default_subset_size % batch_size > 0
-    )
 
-    def transform(x):
-        return x[0]
+    if "fold_quantize" not in kwargs:
+        kwargs["fold_quantize"] = False
 
     quantized_model = nncf_fx.quantize_pt2e(
         captured_model,
         quantizer,
         subset_size=subset_size,
-        calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=transform),
-        fold_quantize=False,
+        calibration_dataset=nncf.Dataset(calibration_dataset, transform_fn),
+        fast_bias_correction=fast_bias_correction,
+        smooth_quant=smooth_quant,
+        **kwargs,
     )
     return quantized_model
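For context, here is a minimal usage sketch of the reworked keyword-only interface. The toy model, the `(input, label)` batch layout, and the capture via `torch.export` are assumptions for illustration; `quantize_model` and `QuantizationMode` are the names from the diff above.

```python
import torch

# Toy model and calibration data, assumed purely for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
).eval()
dataset = [(torch.randn(16), 0) for _ in range(32)]  # (input, label) pairs
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

# Capture the model to a torch.fx.GraphModule via torch.export (PyTorch 2.x);
# ExportedProgram.module() returns the underlying GraphModule.
captured = torch.export.export(model, (torch.randn(8, 16),)).module()

quantized = quantize_model(
    captured,
    loader,
    mode=QuantizationMode.INT8_SYM,
    subset_size=100,
    smooth_quant=False,
    transform_fn=lambda batch: batch[0],  # keep the inputs, drop the labels
)
```

Because `transform_fn` now arrives as a parameter instead of the previously hard-coded `x[0]`, callers whose loaders yield dicts or nested tuples can supply their own extraction logic, and remaining `quantize_pt2e` options such as `fold_quantize` pass through `**kwargs`.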