|
6 | 6 |
|
7 | 7 | from collections import defaultdict |
8 | 8 | from enum import Enum |
9 | | -from typing import Dict, List, Optional, Tuple |
| 9 | +from typing import Dict, List, Optional, Tuple, Callable, Any |
10 | 10 |
|
11 | 11 | import nncf |
12 | 12 | import nncf.common.quantization as quantization |
@@ -351,32 +351,43 @@ def transform_for_annotation( |
351 | 351 |
|
352 | 352 | def quantize_model( |
353 | 353 | captured_model: torch.fx.GraphModule, |
| 354 | + quantizer: Quantizer, |
354 | 355 | calibration_dataset: torch.utils.data.DataLoader, |
| 356 | + subset_size: int, |
| 357 | + fast_bias_correction: Optional[bool] = True, |
| 358 | + smooth_quant: bool = False, |
| 359 | + transform_fn: Optional[Callable[[Any], Any]] = None, |
| 360 | + **kwargs, |
355 | 361 | ) -> torch.fx.GraphModule: |
356 | 362 | """ |
357 | | - Quantizes a model using either NNCF-based or PTQ-based quantization. |
| 363 | + Quantizes a model using the NNCF quantize_pt2e API. |
358 | 364 |
|
359 | 365 | :param captured_model: The model to be quantized, represented as a torch.fx.GraphModule. |
| 366 | + :param quantizer: Torch AO quantizer that annotates graph nodes with quantization settings. |
360 | 367 | :param calibration_dataset: A DataLoader containing calibration data for quantization. |
| 368 | + :param subset_size: Size of the dataset subset used to compute activation |
| 369 | + statistics for quantization. |
| 370 | + :param fast_bias_correction: Setting this option to `False` enables a different |
| 371 | + bias correction method that is, in general, more accurate and requires less |
| 372 | + memory, but takes more time. `None` disables the bias correction algorithm. |
| 373 | + :param smooth_quant: Setting this option to `True` enables the SmoothQuant algorithm. |
| 374 | + :param transform_fn: Function that extracts the model input from each calibration batch. |
| 375 | + :param kwargs: Additional keyword arguments forwarded to the nncf quantize_pt2e function. |
361 | 376 | :return: The quantized model as a torch.fx.GraphModule. |
362 | 376 | """ |
363 | | -quantizer = OpenVINOQuantizer() |
364 | 378 |
|
365 | 379 | print("PTQ: Quantize the model") |
366 | | - default_subset_size = 300 |
367 | | - batch_size = calibration_dataset.batch_size |
368 | | - subset_size = (default_subset_size // batch_size) + int( |
369 | | - default_subset_size % batch_size > 0 |
370 | | - ) |
371 | 380 |
|
372 | | - def transform(x): |
373 | | - return x[0] |
| 381 | + if "fold_quantize" not in kwargs: |
| 382 | + kwargs["fold_quantize"] = False |
374 | 383 |
|
375 | 384 | quantized_model = nncf_fx.quantize_pt2e( |
376 | 385 | captured_model, |
377 | 386 | quantizer, |
378 | 387 | subset_size=subset_size, |
379 | | - calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=transform), |
380 | | - fold_quantize=False, |
| 388 | + calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=transform_fn), |
| 389 | + fast_bias_correction=fast_bias_correction, |
| 390 | + smooth_quant=smooth_quant, |
| 391 | + **kwargs |
381 | 392 | ) |
382 | 393 | return quantized_model |
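
Taken together, the new signature can be exercised as in the sketch below. This is illustrative only, not part of the commit: the toy model, the synthetic calibration data, and the `OpenVINOQuantizer` import path are assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumed import path; recent NNCF releases expose OpenVINOQuantizer here.
from nncf.experimental.torch.fx import OpenVINOQuantizer


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyModel().eval()
example_inputs = (torch.randn(1, 16),)
# Capture the eager model as a torch.fx.GraphModule via torch.export.
captured_model = torch.export.export(model, example_inputs).module()

# Each calibration batch is an (input, label) tuple, so transform_fn
# extracts the tensor the model actually consumes.
calibration_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.zeros(64)), batch_size=8
)

quantized_model = quantize_model(
    captured_model,
    OpenVINOQuantizer(),
    calibration_loader,
    subset_size=300,
    fast_bias_correction=True,
    transform_fn=lambda batch: batch[0],
)
```

Note that `fold_quantize=False` no longer needs to be passed explicitly; the function defaults it through `kwargs` unless the caller overrides it.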