Skip to content

Commit fa8feac

Browse files
Comments
1 parent a28fcf3 commit fa8feac

File tree

2 files changed

+37
-66
lines changed

2 files changed

+37
-66
lines changed

backends/openvino/quantizer/quantizer.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77
from collections import defaultdict
88
from enum import Enum
9+
from itertools import islice
910
from typing import Dict, List, Optional, Tuple
1011

1112
import nncf
1213
import nncf.common.quantization as quantization
1314
import nncf.experimental.torch.fx as nncf_fx
1415

1516
import torch.fx
17+
1618
from nncf.common.graph.graph import NNCFGraph
1719
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
1820
from torch.ao.quantization.quantizer.quantizer import (
@@ -343,5 +345,39 @@ def validate(self, model: torch.fx.GraphModule) -> None:
343345
def transform_for_annotation(
344346
self, model: torch.fx.GraphModule
345347
) -> torch.fx.GraphModule:
348+
# Fold constant branches to avoid their quantization
346349
nncf_fx.transformations.fold_constant_except_qdq(model)
347350
return model
351+
352+
353+
def quantize_model(
    captured_model: torch.fx.GraphModule,
    calibration_dataset: torch.utils.data.DataLoader,
) -> torch.fx.GraphModule:
    """
    Quantizes a model with NNCF's ``quantize_pt2e`` using ``OpenVINOQuantizer``.

    The ``subset_size`` passed to NNCF is ``ceil(300 / batch_size)``, i.e. the
    number of loader items needed to cover about 300 samples.

    :param captured_model: The model to be quantized, represented as a torch.fx.GraphModule.
    :param calibration_dataset: A DataLoader containing calibration data for quantization.
        Element ``[0]`` of every item it yields is used as the model input.
    :return: The quantized model as a torch.fx.GraphModule.
    """
    quantizer = OpenVINOQuantizer()

    print("PTQ: Quantize the model")
    default_subset_size = 300

    # DataLoader.batch_size is None when automatic batching is disabled or a
    # custom batch_sampler is supplied; count every yielded item as one sample
    # in that case instead of failing with a TypeError on the division below.
    batch_size = calibration_dataset.batch_size
    if batch_size is None:
        batch_size = 1

    # Ceiling division: number of loader items covering default_subset_size samples.
    subset_size = (default_subset_size + batch_size - 1) // batch_size

    def transform(x):
        # Keep only the first element of each loader item as the model input.
        return x[0]

    quantized_model = nncf_fx.quantize_pt2e(
        captured_model,
        quantizer,
        subset_size=subset_size,
        calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=transform),
        fold_quantize=False,
    )
    return quantized_model

examples/openvino/aot/aot_openvino_compiler.py

Lines changed: 1 addition & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import os
99
import shutil
1010
import subprocess
11-
from itertools import islice
1211
from pathlib import Path
1312

1413
import executorch
@@ -18,15 +17,13 @@
1817
import timm
1918
import torch
2019
import torchvision.models as torchvision_models
21-
from executorch.backends.openvino import OpenVINOQuantizer
2220
from executorch.backends.openvino.partitioner import OpenvinoPartitioner
21+
from executorch.backends.openvino.quantizer.quantizer import quantize_model
2322
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
2423
from executorch.exir.backend.backend_details import CompileSpec
25-
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e
2624
from sklearn.metrics import accuracy_score
2725
from timm.data import resolve_data_config
2826
from timm.data.transforms_factory import create_transform
29-
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
3027
from torch.export import export
3128
from torch.export.exported_program import ExportedProgram
3229
from torchvision import datasets
@@ -129,55 +126,6 @@ def dump_inputs(calibration_dataset, dest_path):
129126
return input_files, targets
130127

131128

132-
def quantize_model(
133-
captured_model: torch.fx.GraphModule,
134-
calibration_dataset: torch.utils.data.DataLoader,
135-
use_nncf: bool,
136-
) -> torch.fx.GraphModule:
137-
"""
138-
Quantizes a model using either NNCF-based or PTQ-based quantization.
139-
140-
:param captured_model: The model to be quantized, represented as a torch.fx.GraphModule.
141-
:param calibration_dataset: A DataLoader containing calibration data for quantization.
142-
:param use_nncf: Whether to use NNCF-based quantization (True) or standard PTQ (False).
143-
:return: The quantized model as a torch.fx.GraphModule.
144-
"""
145-
quantizer = OpenVINOQuantizer()
146-
147-
print("PTQ: Quantize the model")
148-
default_subset_size = 300
149-
batch_size = calibration_dataset.batch_size
150-
subset_size = (default_subset_size // batch_size) + int(
151-
default_subset_size % batch_size > 0
152-
)
153-
154-
def transform(x):
155-
return x[0]
156-
157-
if use_nncf:
158-
159-
quantized_model = quantize_pt2e(
160-
captured_model,
161-
quantizer,
162-
subset_size=subset_size,
163-
calibration_dataset=nncf.Dataset(
164-
calibration_dataset, transform_func=transform
165-
),
166-
fold_quantize=False,
167-
)
168-
else:
169-
annotated_model = prepare_pt2e(captured_model, quantizer)
170-
171-
print("PTQ: Calibrate the model...")
172-
for data in islice(calibration_dataset, subset_size):
173-
annotated_model(transform(data))
174-
175-
print("PTQ: Convert the quantized model...")
176-
quantized_model = convert_pt2e(annotated_model, fold_quantize=False)
177-
178-
return quantized_model
179-
180-
181129
def validate_model(
182130
model_file_name: str, calibration_dataset: torch.utils.data.DataLoader
183131
) -> float:
@@ -231,7 +179,6 @@ def main(
231179
dataset_path: str,
232180
device: str,
233181
batch_size: int,
234-
quantization_flow: str,
235182
):
236183
"""
237184
Main function to load, quantize, and validate a model.
@@ -244,7 +191,6 @@ def main(
244191
:param dataset_path: Path to the dataset for calibration/validation.
245192
:param device: The device to run the model on (e.g., "cpu", "gpu").
246193
:param batch_size: Batch size for dataset loading.
247-
:param quantization_flow: The quantization method to use.
248194
"""
249195

250196
# Load the selected model
@@ -281,7 +227,6 @@ def main(
281227
quantized_model = quantize_model(
282228
aten_dialect.module(),
283229
calibration_dataset,
284-
use_nncf=quantization_flow == "nncf",
285230
)
286231

287232
aten_dialect: ExportedProgram = export(quantized_model, example_args)
@@ -360,15 +305,6 @@ def main(
360305
default="CPU",
361306
help="Target device for compiling the model (e.g., CPU, GPU). Default is CPU.",
362307
)
363-
parser.add_argument(
364-
"--quantization_flow",
365-
type=str,
366-
choices=["pt2e", "nncf"],
367-
default="nncf",
368-
help="Select the quantization flow (nncf or pt2e):"
369-
" pt2e is the default torch.ao quantization flow, while"
370-
" nncf is a custom method with additional algorithms to improve model performance.",
371-
)
372308

373309
args = parser.parse_args()
374310

@@ -384,5 +320,4 @@ def main(
384320
args.dataset,
385321
args.device,
386322
args.batch_size,
387-
args.quantization_flow,
388323
)

0 commit comments

Comments
 (0)