 import os
 import shutil
 import subprocess
-from itertools import islice
 from pathlib import Path
 
 import executorch
 
-import nncf
+import nncf.torch
 import numpy as np
 import timm
 import torch
 import torchvision.models as torchvision_models
-from executorch.backends.openvino import OpenVINOQuantizer
 from executorch.backends.openvino.partitioner import OpenvinoPartitioner
+from executorch.backends.openvino.quantizer.quantizer import quantize_model
 from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
 from executorch.exir.backend.backend_details import CompileSpec
-from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e
 from sklearn.metrics import accuracy_score
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.export import export
 from torch.export.exported_program import ExportedProgram
 from torchvision import datasets
@@ -129,55 +126,6 @@ def dump_inputs(calibration_dataset, dest_path):
     return input_files, targets
 
 
-def quantize_model(
-    captured_model: torch.fx.GraphModule,
-    calibration_dataset: torch.utils.data.DataLoader,
-    use_nncf: bool,
-) -> torch.fx.GraphModule:
-    """
-    Quantizes a model using either NNCF-based or PTQ-based quantization.
-
-    :param captured_model: The model to be quantized, represented as a torch.fx.GraphModule.
-    :param calibration_dataset: A DataLoader containing calibration data for quantization.
-    :param use_nncf: Whether to use NNCF-based quantization (True) or standard PTQ (False).
-    :return: The quantized model as a torch.fx.GraphModule.
-    """
-    quantizer = OpenVINOQuantizer()
-
-    print("PTQ: Quantize the model")
-    default_subset_size = 300
-    batch_size = calibration_dataset.batch_size
-    subset_size = (default_subset_size // batch_size) + int(
-        default_subset_size % batch_size > 0
-    )
-
-    def transform(x):
-        return x[0]
-
-    if use_nncf:
-
-        quantized_model = quantize_pt2e(
-            captured_model,
-            quantizer,
-            subset_size=subset_size,
-            calibration_dataset=nncf.Dataset(
-                calibration_dataset, transform_func=transform
-            ),
-            fold_quantize=False,
-        )
-    else:
-        annotated_model = prepare_pt2e(captured_model, quantizer)
-
-        print("PTQ: Calibrate the model...")
-        for data in islice(calibration_dataset, subset_size):
-            annotated_model(transform(data))
-
-        print("PTQ: Convert the quantized model...")
-        quantized_model = convert_pt2e(annotated_model, fold_quantize=False)
-
-    return quantized_model
-
-
 def validate_model(
     model_file_name: str, calibration_dataset: torch.utils.data.DataLoader
 ) -> float:
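
For review context, the NNCF branch of the deleted helper boils down to the call sketched below. This condenses only the removed code (imports, names, and values are taken from the deleted lines above); it is not a claim about how the newly imported quantize_model is implemented.

import nncf
from executorch.backends.openvino import OpenVINOQuantizer
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e

# Calibrate on roughly 300 samples, rounding the subset size up to whole batches.
subset_size = (300 // calibration_dataset.batch_size) + int(
    300 % calibration_dataset.batch_size > 0
)
quantized_model = quantize_pt2e(
    captured_model,
    OpenVINOQuantizer(),
    subset_size=subset_size,
    # The DataLoader yields (inputs, targets); only the inputs feed calibration.
    calibration_dataset=nncf.Dataset(calibration_dataset, transform_func=lambda x: x[0]),
    fold_quantize=False,
)
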
@@ -231,7 +179,6 @@ def main(
     dataset_path: str,
     device: str,
     batch_size: int,
-    quantization_flow: str,
 ):
     """
     Main function to load, quantize, and validate a model.
@@ -244,7 +191,6 @@ def main(
     :param dataset_path: Path to the dataset for calibration/validation.
     :param device: The device to run the model on (e.g., "cpu", "gpu").
     :param batch_size: Batch size for dataset loading.
-    :param quantization_flow: The quantization method to use.
     """
 
     # Load the selected model
@@ -281,7 +227,6 @@ def main(
     quantized_model = quantize_model(
         aten_dialect.module(),
         calibration_dataset,
-        use_nncf=quantization_flow == "nncf",
     )
 
     aten_dialect: ExportedProgram = export(quantized_model, example_args)
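
With the helper gone, the call site above is the whole quantization step in this example. A minimal sketch of the surrounding flow in main(), assuming (as the remaining imports suggest) that quantizer selection and calibration subset sizing now happen inside the imported quantize_model; model and example_args stand for the objects prepared earlier in the script.

# Capture the eager model, quantize it, and re-export the quantized module
# before it is lowered to the OpenVINO backend further down in main().
aten_dialect: ExportedProgram = export(model, example_args)
quantized_model = quantize_model(aten_dialect.module(), calibration_dataset)
aten_dialect = export(quantized_model, example_args)
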
@@ -360,15 +305,6 @@ def main(
360305 default = "CPU" ,
361306 help = "Target device for compiling the model (e.g., CPU, GPU). Default is CPU." ,
362307 )
363- parser .add_argument (
364- "--quantization_flow" ,
365- type = str ,
366- choices = ["pt2e" , "nncf" ],
367- default = "nncf" ,
368- help = "Select the quantization flow (nncf or pt2e):"
369- " pt2e is the default torch.ao quantization flow, while"
370- " nncf is a custom method with additional algorithms to improve model performance." ,
371- )
372308
373309 args = parser .parse_args ()
374310
@@ -384,5 +320,4 @@ def main(
         args.dataset,
         args.device,
         args.batch_size,
-        args.quantization_flow,
     )