Skip to content

Commit 4685a37

Browse files
quantize_model cleanup
1 parent cff39cd commit 4685a37

File tree

2 files changed

+38
-13
lines changed

2 files changed

+38
-13
lines changed

backends/openvino/quantizer/quantizer.py

Lines changed: 22 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
from collections import defaultdict
88
from enum import Enum
9-
from typing import Dict, List, Optional, Tuple
9+
from typing import Dict, List, Optional, Tuple, Callable, Any
1010

1111
import nncf
1212
import nncf.common.quantization as quantization
@@ -351,32 +351,43 @@ def transform_for_annotation(
351351

352352
def quantize_model(
353353
captured_model: torch.fx.GraphModule,
354+
quantizer: Quantizer,
354355
calibration_dataset: torch.utils.data.DataLoader,
356+
subset_size: int,
357+
fast_bias_correction: Optional[bool] = True,
358+
smooth_quant: bool = False,
359+
transform_fn: Optional[Callable[[Any], Any]]= None,
360+
**kwargs,
355361
) -> torch.fx.GraphModule:
356362
"""
357-
Quantizes a model using either NNCF-based or PTQ-based quantization.
363+
Quantizes a model using NNCF quantize_pt2e API.
358364
359365
:param captured_model: The model to be quantized, represented as a torch.fx.GraphModule.
366+
:param quantizer: Torch ao quantizer to annotate nodes in the graph with quantization setups
360367
:param calibration_dataset: A DataLoader containing calibration data for quantization.
368+
:param subset_size: Size of a subset to calculate activations
369+
statistics used for quantization.
370+
:param fast_bias_correction: Setting this option to `False` enables a different
371+
bias correction method which is more accurate, in general, and takes
372+
more time but requires less memory. None disables the bias correction algorithm.
373+
:param smooth_quant: Setting this option to `True` enables the SmoothQuant algorithm.
374+
:param kwargs: The keyword arguments for the nncf quantize_pt2e function.
361375
:return: The quantized model as a torch.fx.GraphModule.
362376
"""
363377
quantizer = OpenVINOQuantizer()
364378

365379
print("PTQ: Quantize the model")
366-
default_subset_size = 300
367-
batch_size = calibration_dataset.batch_size
368-
subset_size = (default_subset_size // batch_size) + int(
369-
default_subset_size % batch_size > 0
370-
)
371380

372-
def transform(x):
373-
return x[0]
381+
if "fold_quantize" not in kwargs:
382+
kwargs["fold_quantize"] = False
374383

375384
quantized_model = nncf_fx.quantize_pt2e(
376385
captured_model,
377386
quantizer,
378387
subset_size=subset_size,
379-
calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=transform),
380-
fold_quantize=False,
388+
calibration_dataset=nncf.Dataset(calibration_dataset, transform_fn),
389+
fast_bias_correction=fast_bias_correction,
390+
smooth_quant=smooth_quant,
391+
**kwargs
381392
)
382393
return quantized_model

examples/openvino/aot_openvino_compiler.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,10 @@
1313
import torch
1414
import torchvision.models as torchvision_models
1515
from executorch.backends.openvino.partitioner import OpenvinoPartitioner
16-
from executorch.backends.openvino.quantizer.quantizer import quantize_model
16+
from executorch.backends.openvino.quantizer.quantizer import (
17+
OpenVINOQuantizer,
18+
quantize_model,
19+
)
1720
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
1821
from executorch.exir.backend.backend_details import CompileSpec
1922
from executorch.extension.pybindings.portable_lib import ( # @manual
@@ -182,9 +185,20 @@ def main(
182185
if not dataset_path:
183186
msg = "Quantization requires a calibration dataset."
184187
raise ValueError(msg)
188+
189+
subset_size = 300
190+
batch_size = calibration_dataset.batch_size
191+
subset_size = (subset_size // batch_size) + int(subset_size % batch_size > 0)
192+
193+
quantizer = OpenVINOQuantizer()
194+
195+
transform_fn = lambda x: x[0]
185196
quantized_model = quantize_model(
186197
aten_dialect.module(),
187-
calibration_dataset,
198+
quantizer=quantizer,
199+
calibration_dataset=calibration_dataset,
200+
subset_size=subset_size,
201+
transform_fn=transform_fn,
188202
)
189203

190204
aten_dialect: ExportedProgram = export(quantized_model, example_args)

0 commit comments

Comments (0)