Commit 4c2b568

Sanity test and code cleaning
1 parent 4641f8d commit 4c2b568

File tree

13 files changed: +519 −1538 lines changed

examples/post_training_quantization/openvino/yolov8/main.py

Lines changed: 55 additions & 62 deletions
@@ -18,7 +18,7 @@
 import time
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple

 import numpy as np
 import openvino as ov
@@ -29,13 +29,10 @@
 from torch.fx.passes.graph_drawer import FxGraphDrawer
 from tqdm import tqdm
 from ultralytics.cfg import get_cfg
-from ultralytics.data.converter import coco80_to_coco91_class
 from ultralytics.data.utils import check_det_dataset
 from ultralytics.engine.validator import BaseValidator as Validator
 from ultralytics.models.yolo import YOLO
-from ultralytics.utils import DATASETS_DIR
 from ultralytics.utils import DEFAULT_CFG
-from ultralytics.utils.metrics import ConfusionMatrix
 from ultralytics.utils.torch_utils import de_parallel

 import nncf
@@ -55,15 +52,18 @@ def measure_time(model, example_inputs, num_iters=500):
     return average_time


-def validate_fx_ult_method(model: ov.Model) -> Tuple[Dict, int, int]:
-    """
-    Uses .val ultralitics method instead of a dataloader loop.
-    For some reason this shows better metrics on torch.compiled models
-    """
-    yolo = YOLO(f"{ROOT}/{MODEL_NAME}.pt")
-    yolo.model = model
-    result = yolo.val(data="coco128.yaml", batch=1, rect=False)
-    return result.results_dict
+def measure_time_ov(model, example_inputs, num_iters=1000):
+    ie = ov.Core()
+    compiled_model = ie.compile_model(model, "CPU")
+    infer_request = compiled_model.create_infer_request()
+    infer_request.infer(example_inputs)
+    total_time = 0
+    for i in range(0, num_iters):
+        start_time = time.time()
+        infer_request.infer(example_inputs)
+        total_time += time.time() - start_time
+    average_time = (total_time / num_iters) * 1000
+    return average_time


 def validate_fx(
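The new `measure_time_ov` helper compiles an OpenVINO model for CPU, runs one warm-up inference, and returns the mean latency in milliseconds over `num_iters` synchronous requests. A minimal usage sketch, where the IR path and the 640x640 input shape are illustrative placeholders rather than values from this commit:

```python
# Hedged usage sketch for measure_time_ov; "yolov8n.xml" and the input
# shape are assumptions for illustration only.
import numpy as np
import openvino as ov

core = ov.Core()
ir_model = core.read_model("yolov8n.xml")
dummy_input = np.random.rand(1, 3, 640, 640).astype(np.float32)
latency_ms = measure_time_ov(ir_model, dummy_input, num_iters=100)
print(f"Mean CPU latency: {latency_ms:.2f} ms")
```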
@@ -100,10 +100,10 @@ def print_statistics_short(stats: np.ndarray) -> None:
 def validate_ov(
     model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
 ) -> Tuple[Dict, int, int]:
-    validator.seen = 0
-    validator.jdict = []
-    validator.stats = []
-    validator.confusion_matrix = ConfusionMatrix(nc=validator.nc)
+    # validator.seen = 0
+    # validator.jdict = []
+    # validator.stats = []
+    # validator.confusion_matrix = ConfusionMatrix(nc=validator.nc)
     model.reshape({0: [1, 3, -1, -1]})
     compiled_model = ov.compile_model(model)
     output_layer = compiled_model.output(0)
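`validate_ov` now leaves the validator's bookkeeping untouched and relies on `model.reshape({0: [1, 3, -1, -1]})` to make the spatial dimensions dynamic before compilation. Standalone, the pattern looks like this (the IR file name is a placeholder):

```python
# Sketch of OpenVINO dynamic spatial dims; "model.xml" is a placeholder.
import openvino as ov

core = ov.Core()
ir_model = core.read_model("model.xml")
ir_model.reshape({0: [1, 3, -1, -1]})  # -1 marks height/width as dynamic
compiled = ov.compile_model(ir_model)  # one compiled model, many input sizes
```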
@@ -131,7 +131,7 @@ def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -
     print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap))


-def prepare_validation_new(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]:
+def prepare_validation(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]:
     # custom = {"rect": True, "batch": 1}  # method defaults
     # rect: false forces to resize all input pictures to one size
     custom = {"rect": False, "batch": 1}  # method defaults
@@ -148,25 +148,6 @@ def prepare_validation_new(model: YOLO, data: str) -> Tuple[Validator, torch.uti
     return validator, data_loader


-def prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.data.DataLoader]:
-    validator = model.smart_load("validator")(args)
-    validator.data = check_det_dataset(args.data)
-    dataset = validator.data["val"]
-    print(f"{dataset}")
-
-    data_loader = validator.get_dataloader(f"{DATASETS_DIR}/coco128", 1)
-
-    validator = model.smart_load("validator")(args)
-
-    validator.is_coco = True
-    validator.class_map = coco80_to_coco91_class()
-    validator.names = model.model.names
-    validator.metrics.names = validator.names
-    validator.nc = model.model.model[-1].nc
-
-    return validator, data_loader
-
-
 def benchmark_performance(model_path, config) -> float:
     command = f"benchmark_app -m {model_path} -d CPU -api async -t 30"
     command += f' -shape "[1,3,{config.imgsz},{config.imgsz}]"'
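`benchmark_performance` shells out to OpenVINO's `benchmark_app` with the async API and a fixed 30-second run. The parsing of the tool's output is not visible in this hunk; one plausible sketch, assuming the standard `Throughput: ... FPS` line in the log (an assumption about the tool's output format, not code from this commit):

```python
# Hedged sketch: extracting throughput from benchmark_app output.
# The "Throughput: <value> FPS" log line is an assumed format.
import re
import subprocess

def parse_benchmark_fps(command: str) -> float:
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    match = re.search(r"Throughput:\s*([\d.]+)\s*FPS", result.stdout)
    return float(match.group(1)) if match else -1.0
```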
@@ -221,7 +202,7 @@ def transform_fn(data_item: Dict):
     return quantized_model


-NNCF_QUANTIZATION = True
+NNCF_QUANTIZATION = False


 def quantize_impl(exported_model, val_loader, validator):
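`NNCF_QUANTIZATION` now defaults to `False`, so the plain `torch.compile` path is exercised unless the flag is flipped. When it is enabled, `quantize_impl` performs NNCF post-training quantization of the captured model; the general shape of such a call is sketched below, with the dataset wrapping and subset size as illustrative assumptions rather than this commit's exact body:

```python
# Hedged sketch of an NNCF PTQ call as quantize_impl might perform it;
# the subset_size value is illustrative.
import nncf

def quantize_sketch(exported_model, val_loader, transform_fn):
    calibration_dataset = nncf.Dataset(val_loader, transform_fn)
    return nncf.quantize(exported_model, calibration_dataset, subset_size=300)
```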
@@ -290,26 +271,25 @@ def main():
     # args.data = "coco128.yaml"
     # Prepare validation dataset and helper

-    validator, data_loader = prepare_validation_new(model, "coco128.yaml")
+    validator, data_loader = prepare_validation(model, "coco128.yaml")

     # Convert to OpenVINO model
-    if TORCH_FX:
-        batch = next(iter(data_loader))
-        batch = validator.preprocess(batch)
+    batch = next(iter(data_loader))
+    batch = validator.preprocess(batch)

+    if TORCH_FX:
         fp_stats, total_images, total_objects = validate_fx(model.model, tqdm(data_loader), validator)
         print("Floating-point Torch model validation results:")
         print_statistics(fp_stats, total_images, total_objects)

-        fp32_compiled_model = torch.compile(model.model, backend="openvino")
+        if NNCF_QUANTIZATION:
+            fp32_compiled_model = torch.compile(model.model, backend="openvino")
+        else:
+            fp32_compiled_model = torch.compile(model.model)
         fp32_stats, total_images, total_objects = validate_fx(fp32_compiled_model, tqdm(data_loader), validator)
         print("FP32 FX model validation results:")
         print_statistics(fp32_stats, total_images, total_objects)

-        # result = validate_fx_ult_method(fp32_compiled_model)
-        # print("FX FP32 model .val validation")
-        # print_statistics_short(result)
-
         print("Start quantization...")
         # Rebuild model to reset ultralitics cache
         model = YOLO(f"{ROOT}/{MODEL_NAME}.pt")
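The FP32 baseline now picks its `torch.compile` backend from the same flag: the OpenVINO backend when `NNCF_QUANTIZATION` is set, the default Inductor backend otherwise. In isolation the two paths look like this, with `net` as a placeholder module and assuming the OpenVINO torch integration is installed for the first call:

```python
# Both are standard torch.compile invocations; "net" is a placeholder.
import torch

net = torch.nn.Conv2d(3, 8, 3)
compiled_ov = torch.compile(net, backend="openvino")  # needs OpenVINO's backend
compiled_default = torch.compile(net)                 # default Inductor backend
```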
@@ -323,10 +303,6 @@ def main():
     )
     quantized_model = quantize_impl(deepcopy(exported_model), data_loader, validator)

-    # result = validate_fx_ult_method(quantized_model)
-    # print("FX INT8 model .val validation")
-    # print_statistics_short(result)
-
     int8_stats, total_images, total_objects = validate_fx(quantized_model, tqdm(data_loader), validator)
     print("INT8 FX model validation results:")
     print_statistics(int8_stats, total_images, total_objects)
@@ -360,35 +336,52 @@ def main():
     print("Quantized model validation results:")
     print_statistics(q_stats, total_images, total_objects)

-    # Benchmark performance of FP32 model
-    fp_model_perf = benchmark_performance(ov_model_path, args)
-    print(f"Floating-point model performance: {fp_model_perf} FPS")
-
-    # Benchmark performance of quantized model
-    quantized_model_perf = benchmark_performance(quantized_model_path, args)
-    print(f"Quantized model performance: {quantized_model_perf} FPS")
+    fps = True
+    latency = True
+    fp_model_perf = -1
+    quantized_model_perf = -1
+    if fps:
+        # Benchmark performance of FP32 model
+        fp_model_perf = benchmark_performance(ov_model_path, args)
+        print(f"Floating-point model performance: {fp_model_perf} FPS")
+
+        # Benchmark performance of quantized model
+        quantized_model_perf = benchmark_performance(quantized_model_path, args)
+        print(f"Quantized model performance: {quantized_model_perf} FPS")
+    if latency:
+        fp_model_latency = measure_time_ov(ov_model, batch["img"])
+        print(f"FP32 OV model latency: {fp_model_latency}")
+        int8_model_latency = measure_time_ov(quantized_model, batch["img"])
+        print(f"INT8 OV model latency: {int8_model_latency}")

     return fp_stats["metrics/mAP50-95(B)"], q_stats["metrics/mAP50-95(B)"], fp_model_perf, quantized_model_perf


-def check_export_not_strict():
+def main_export_not_strict():
     model = YOLO(f"{ROOT}/{MODEL_NAME}.pt")

     # Prepare validation dataset and helper
-    validator, data_loader = prepare_validation_new(model, "coco128.yaml")
+    validator, data_loader = prepare_validation(model, "coco128.yaml")

     batch = next(iter(data_loader))
     batch = validator.preprocess(batch)

     model.model(batch["img"])
     ex_model = torch.export.export(model.model, args=(batch["img"],), strict=False)
     ex_model = capture_pre_autograd_graph(ex_model.module(), args=(batch["img"],))
+    ex_model = torch.compile(ex_model)

     fp_stats, total_images, total_objects = validate_fx(ex_model, tqdm(data_loader), validator)
     print("Floating-point ex strict=False")
     print_statistics(fp_stats, total_images, total_objects)

+    quantized_model = quantize_impl(deepcopy(ex_model), data_loader, validator)
+    int8_stats, total_images, total_objects = validate_fx(quantized_model, tqdm(data_loader), validator)
+    print("Int8 ex strict=False")
+    print_statistics(int8_stats, total_images, total_objects)
+    # No quantized were inserted, metrics are OK


 if __name__ == "__main__":
-    check_export_not_strict()
-    # main()
+    # main_export_not_strict()
+    main()
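The entry point switches back to `main()`, with `main_export_not_strict` kept as an alternative path that exports the model without strict tracing before quantization. A standalone sketch of that export sequence follows; `net` and `example` are placeholders, and `capture_pre_autograd_graph` was a private torch API at the time of this commit:

```python
# Hedged sketch of the non-strict export path; "net" and "example" are
# placeholders, not objects from this commit.
import torch
from torch._export import capture_pre_autograd_graph  # private API

net = torch.nn.Conv2d(3, 8, 3)
example = torch.randn(1, 3, 64, 64)
exported = torch.export.export(net, args=(example,), strict=False)
fx_model = capture_pre_autograd_graph(exported.module(), args=(example,))
fx_model = torch.compile(fx_model)
```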

nncf/experimental/torch_fx/model_transformer.py

Lines changed: 1 addition & 152 deletions
@@ -10,48 +10,23 @@
 # limitations under the License.

 from collections import defaultdict
-from dataclasses import dataclass

-# from functools import partial
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Union

 import torch
 import torch.fx
-from torch.ao.quantization.fx.utils import create_getattr_from_value
-from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
-from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
-from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat
-from torch.ao.quantization.pt2e.utils import _disallow_eval_train
-from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
-from torch.fx import GraphModule
-from torch.fx.passes.infra.pass_manager import PassManager
 from torch.fx.passes.split_utils import split_by_tags

 from nncf.common.graph.model_transformer import ModelTransformer
-
-# from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
-# from nncf.common.graph.transformations.commands import TransformationPriority
 from nncf.common.graph.transformations.commands import Command
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationPriority
 from nncf.common.graph.transformations.commands import TransformationType
-
-# from torch import Tensor
-# from torch import nn
 from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
-
-# from nncf.torch.graph.transformations.commands import PTTargetPoint
-# from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
 from nncf.torch.graph.transformations.layout import PTTransformationLayout

-# from torch.nn.parameter import Parameter
-# from nncf.torch.model_graph_manager import update_fused_bias
-# from nncf.torch.nncf_network import PTInsertionPoint
-# from nncf.torch.nncf_network import compression_module_type_to_attr_name
-# from nncf.torch.utils import get_model_device
-# from nncf.torch.utils import is_multidevice
-

 class FXModuleInsertionCommand:
     def __init__(
@@ -206,129 +181,3 @@ def _apply_transformation(
     for transformation in transformations:
         transformation.tranformation_fn(model)
     return model
-
-
-@dataclass
-class QPARAMSPerTensor:
-    scale: float
-    zero_point: int
-    quant_min: int
-    quant_max: int
-    dtype: torch.dtype
-
-
-@dataclass
-class QPARAMPerChannel:
-    scales: torch.Tensor
-    zero_points: Optional[torch.Tensor]
-    axis: int
-    quant_min: int
-    quant_max: int
-    dtype: torch.dtype
-
-
-def insert_qdq_to_model(model: torch.fx.GraphModule, qsetup) -> torch.fx.GraphModule:
-    # from prepare
-    _fuse_conv_bn_(model)
-
-    # from convert
-    original_graph_meta = model.meta
-    _insert_qdq_to_model(model, qsetup)
-
-    # Magic. Without this call compiled model
-    # is not preformant
-    model = GraphModule(model, model.graph)
-
-    model = _fold_conv_bn_qat(model)
-    pm = PassManager([DuplicateDQPass()])
-
-    model = pm(model).graph_module
-    pm = PassManager([PortNodeMetaForQDQ()])
-    model = pm(model).graph_module
-
-    model.meta.update(original_graph_meta)
-    model = _disallow_eval_train(model)
-    return model
-
-
-def _insert_qdq_to_model(model: torch.fx.GraphModule, qsetup) -> torch.fx.GraphModule:
-    for idx, node in enumerate(list(model.graph.nodes)):
-        if node.name not in qsetup:
-            continue
-        # 1. extract information for inserting q/dq node from activation_post_process
-        params = qsetup[node.name]
-        node_type = "call_function"
-        quantize_op: Optional[Callable] = None
-        # scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
-        if isinstance(params, QPARAMPerChannel):
-            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
-            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
-            qparams = {
-                "_scale_": params.scales,
-                "_zero_point_": params.zero_points,
-                "_axis_": params.axis,
-                "_quant_min_": params.quant_min,
-                "_quant_max_": params.quant_max,
-                "_dtype_": params.dtype,
-            }
-        elif isinstance(params, QPARAMSPerTensor):
-            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
-            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
-            qparams = {
-                "_scale_": params.scale,
-                "_zero_point_": params.zero_point,
-                "_quant_min_": params.quant_min,
-                "_quant_max_": params.quant_max,
-                "_dtype_": params.dtype,
-            }
-
-        else:
-            raise RuntimeError(f"params {params} are unknown")
-        # 2. replace activation_post_process node with quantize and dequantize
-        graph = model.graph
-
-        # TODO: use metatype to get correct input_port_id
-        # Do not quantize already quantized nodes
-        # inserting_before handle only order in the graph generated code.
-        # so, inserting quantize-dequantize and all constant nodes before the usage of the nodes
-        with graph.inserting_before(node):
-            quantize_op_inputs = [node]
-            for key, value_or_node in qparams.items():
-                # TODO: we can add the information of whether a value needs to
-                # be registered as an attribute in qparams dict itself
-                if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))):
-                    # For scale and zero_point values we register them as buffers in the root module.
-                    # However, note that when the values are not tensors, as in the case of
-                    # per_tensor quantization, they will be treated as literals.
-                    # However, registering them as a node seems to cause issue with dynamo
-                    # tracing where it may consider tensor overload as opposed to default.
-                    # With extra check of scale and zero_point being scalar, it makes
-                    # sure that the default overload can be used.
-                    # TODO: maybe need more complex attr name here
-                    qparam_node = create_getattr_from_value(model, graph, str(idx) + key, value_or_node)
-                    quantize_op_inputs.append(qparam_node)
-                else:
-                    # for qparams that are not scale/zero_point (like axis, dtype) we store
-                    # them as literals in the graph.
-                    quantize_op_inputs.append(value_or_node)
-
-        with graph.inserting_after(node):
-            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
-        # use the same qparams from quantize op
-        dq_inputs = [quantized_node] + quantize_op_inputs[1:]
-        user_dq_nodes = []
-        with graph.inserting_after(quantized_node):
-            for user in node.users:
-                if user is quantized_node:
-                    continue
-                user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))
-
-        for user, dq_node in user_dq_nodes:
-            user.replace_input_with(node, dq_node)
-
-        # node.replace_all_uses_with(dequantized_node)
-        # graph.erase_node(node)
-    from torch.fx.passes.graph_drawer import FxGraphDrawer
-
-    g = FxGraphDrawer(model, "model_after_qdq_insertion")
-    g.get_dot_graph().write_svg("model_after_qdq_insertion.svg")
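The deleted `_insert_qdq_to_model` pass wove quantize/dequantize pairs from `torch.ops.quantized_decomposed` into the FX graph, registering tensor-valued scales and zero points as module attributes. For reference, a minimal per-tensor round trip through those same ops; the scale and zero point here are arbitrary illustration values:

```python
# Standalone round trip through the decomposed Q/DQ ops that the removed
# pass inserted into the graph; scale/zero_point values are arbitrary.
import torch

x = torch.randn(2, 3)
scale, zero_point = 0.02, 0
q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x, scale, zero_point, -128, 127, torch.int8
)
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    q, scale, zero_point, -128, 127, torch.int8
)
print((x - dq).abs().max())  # quantization error bounded by roughly scale / 2
```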
