
Commit 09082ee

WIP quantization
1 parent 01d12f5 commit 09082ee

File tree

12 files changed: +347 -293 lines changed

backends/openvino/quantizer/quantizer.py

Lines changed: 2 additions & 2 deletions

@@ -88,7 +88,7 @@ def set_ignored_scope(
         names: Optional[List[str]] = None,
         patterns: Optional[List[str]] = None,
         types: Optional[List[str]] = None,
-        subgraphs: Optional[List[Tuple[List[str], List[str]]]] = None,
+        subgraphs: Optional[List[nncf.Subgraph]] = None,
         validate: bool = True,
     ) -> None:
         """
@@ -107,7 +107,7 @@ def set_ignored_scope(
             names=names or [],
             patterns=patterns or [],
             types=types or [],
-            subgraphs=subgraphs or [],
+            subgraphs=subgraphs or nncf.Subgraph(),
             validate=validate,
         )
     )
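
With this change, ignored subgraphs are passed as nncf.Subgraph objects instead of bare (inputs, outputs) tuples. A minimal usage sketch, mirroring the call that the YOLO12 example below makes (the node names "cat_18" and "output" come from that model's exported FX graph and are specific to it):

    import nncf
    from executorch.backends.openvino.quantizer import OpenVINOQuantizer
    from executorch.backends.openvino.quantizer.quantizer import QuantizationMode

    quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8_TRANSFORMER)
    quantizer.set_ignored_scope(
        # Skip quantization for these op types entirely
        types=["mul", "sub", "sigmoid", "__getitem__"],
        # Keep the post-processing subgraph (detection head output) in FP32
        subgraphs=[nncf.Subgraph(inputs=["cat_18"], outputs=["output"])],
    )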

examples/yolo12_/CMakeLists.txt renamed to examples/models/yolo12/CMakeLists.txt

Lines changed: 7 additions & 8 deletions

@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.5)
 
-project(Yolov8CPPInference VERSION 0.1)
+project(Yolo12DetectionDemo VERSION 0.1)
 
 option(USE_OPENVINO_BACKEND "Build the tutorial with the OpenVINO backend" ON)
 option(USE_XNNPACK_BACKEND "Build the tutorial with the XNNPACK backend" OFF)
@@ -20,7 +20,7 @@ if(NOT PYTHON_EXECUTABLE)
   set(PYTHON_EXECUTABLE python3)
 endif()
 
-set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
+set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
 set(TORCH_ROOT ${EXECUTORCH_ROOT}/third-party/pytorch)
 
 include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
@@ -29,8 +29,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)
 
 # find `executorch` libraries, same as for gflags
-set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch)
-find_package(executorch CONFIG REQUIRED)
+find_package(executorch CONFIG REQUIRED PATHS ${EXECUTORCH_ROOT}/cmake-out)
 target_link_options_shared_lib(executorch)
 
 set(link_libraries gflags)
@@ -78,8 +77,8 @@ set(PROJECT_SOURCES
   ${EXECUTORCH_ROOT}/extension/runner_util/inputs_portable.cpp
 )
 
-add_executable(Yolov8CPPInference ${PROJECT_SOURCES})
-target_link_libraries(Yolov8CPPInference PUBLIC
+add_executable(Yolo12DetectionDemo ${PROJECT_SOURCES})
+target_link_libraries(Yolo12DetectionDemo PUBLIC
   ${link_libraries}
   ${OpenCV_LIBS}
   executorch_core
@@ -88,5 +87,5 @@ target_link_libraries(Yolov8CPPInference PUBLIC
 )
 
 find_package(Threads REQUIRED)
-target_link_libraries(Yolov8CPPInference PRIVATE Threads::Threads)
-target_include_directories(Yolov8CPPInference PUBLIC ${_common_include_directories})
+target_link_libraries(Yolo12DetectionDemo PRIVATE Threads::Threads)
+target_include_directories(Yolo12DetectionDemo PUBLIC ${_common_include_directories})

File renamed without changes.

examples/models/yolo12/build.sh

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+rm -r build
+mkdir build && cd build
+cmake -DCMAKE_BUILD_TYPE=Debug -DUSE_XNNPACK_BACKEND=ON -DUSE_OPENVINO_BACKEND=ON ..
+make -j 30
Lines changed: 325 additions & 0 deletions

@@ -0,0 +1,325 @@
+# Copyright (c) Intel Corporation
+#
+# Licensed under the BSD License (the "License"); you may not use this file
+# except in compliance with the License. See the license file found in the
+# LICENSE file in the root directory of this source tree.
+
+# mypy: disable-error-code="import-untyped,import-not-found"
+
+
+import argparse
+from itertools import islice
+from typing import Any, Callable, Iterator, Tuple
+
+import cv2
+import executorch
+import nncf.torch
+import numpy as np
+import torch
+from executorch.backends.openvino.partitioner import OpenvinoPartitioner
+from executorch.backends.openvino.quantizer import OpenVINOQuantizer
+from executorch.backends.openvino.quantizer.quantizer import QuantizationMode
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.exir import (
+    EdgeCompileConfig,
+    EdgeProgramManager,
+    ExecutorchBackendConfig,
+    ExecutorchProgramManager,
+    to_edge_transform_and_lower,
+)
+from executorch.exir.backend.backend_details import CompileSpec
+from nncf.experimental.torch.fx import quantize_pt2e
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.export.exported_program import ExportedProgram
+from torch.fx.passes.graph_drawer import FxGraphDrawer
+from ultralytics import YOLO
+
+
+class CV2VideoIter:
+    def __init__(self, cap) -> None:
+        self._cap = cap
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        success, frame = self._cap.read()
+        if not success:
+            raise StopIteration()
+        return frame
+
+    def __len__(self):
+        return int(self._cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+
+class CV2VideoDataset(torch.utils.data.IterableDataset):
+    def __init__(self, cap) -> None:
+        super().__init__()
+        self._iter = CV2VideoIter(cap)
+
+    def __iter__(self) -> Iterator:
+        return self._iter
+
+    def __len__(self):
+        return len(self._iter)
+
+
+def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str):
+    g = FxGraphDrawer(model, output_svg_path)
+    g.get_dot_graph().write_svg(output_svg_path)
+
+
+def lower_to_openvino(
+    aten_dialect: ExportedProgram,
+    example_args: Tuple[Any, ...],
+    transform_fn: Callable,
+    device: str,
+    calibration_dataset: CV2VideoDataset,
+    subset_size: int,
+    quantize: bool,
+) -> ExecutorchProgramManager:
+    if quantize:
+        target_input_dims = tuple(example_args[0].shape[2:])
+
+        def ext_transform_fn(sample):
+            sample = transform_fn(sample)
+            return pad_to_target(sample, target_input_dims)
+
+        quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8_TRANSFORMER)
+        quantizer.set_ignored_scope(
+            types=["mul", "sub", "sigmoid", "__getitem__"],
+            subgraphs=[nncf.Subgraph(inputs=["cat_18"], outputs=["output"])],
+        )
+        quantized_model = quantize_pt2e(
+            aten_dialect.module(),
+            quantizer,
+            nncf.Dataset(calibration_dataset, ext_transform_fn),
+            subset_size=subset_size,
+            smooth_quant=True,
+            fold_quantize=False,
+        )
+
+        visualize_fx_model(quantized_model, "tmp_quantized_model.svg")
+        aten_dialect = torch.export.export(quantized_model, example_args)
+    # Convert to edge dialect and lower the module to the backend with a custom partitioner
+    compile_spec = [CompileSpec("device", device.encode())]
+    lowered_module: EdgeProgramManager = to_edge_transform_and_lower(
+        aten_dialect,
+        partitioner=[
+            OpenvinoPartitioner(compile_spec),
+        ],
+        compile_config=EdgeCompileConfig(
+            _skip_dim_order=True,
+        ),
+    )
+
+    # Apply backend-specific passes
+    return lowered_module.to_executorch(
+        config=executorch.exir.ExecutorchBackendConfig()
+    )
+
+
+def lower_to_xnnpack(
+    aten_dialect: ExportedProgram,
+    example_args: Tuple[Any, ...],
+    transform_fn: Callable,
+    device: str,
+    calibration_dataset: CV2VideoDataset,
+    subset_size: int,
+    quantize: bool,
+) -> ExecutorchProgramManager:
+    if quantize:
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(
+            is_per_channel=False,
+            is_dynamic=False,
+        )
+        quantizer.set_global(operator_config)
+        m = prepare_pt2e(aten_dialect.module(), quantizer)
+        # Calibration: run padded frames through the observed model
+        target_input_dims = tuple(example_args[0].shape[2:])
+        print("Start quantization...")
+        for sample in islice(calibration_dataset, subset_size):
+            sample = transform_fn(sample)
+            sample = pad_to_target(sample, target_input_dims)
+            m(sample)
+        m = convert_pt2e(m)
+        print("Quantized successfully!")
+        aten_dialect = torch.export.export(m, example_args)
+
+    edge = to_edge_transform_and_lower(
+        aten_dialect,
+        partitioner=[XnnpackPartitioner()],
+        compile_config=EdgeCompileConfig(
+            _check_ir_validity=not quantize,
+            _skip_dim_order=True,  # TODO(T182187531): enable dim order in xnnpack
+        ),
+    )
+
+    return edge.to_executorch(
+        config=ExecutorchBackendConfig(extract_delegate_segments=False)
+    )
+
+
+def pad_to_target(
+    image: torch.Tensor,
+    target_size: Tuple[int, int],
+):
+    # Zero-pad the image symmetrically to the target spatial size
+    if image.shape[2:] == target_size:
+        return image
+    img_h, img_w = image.shape[2:]
+    target_h, target_w = target_size
+
+    diff_h = target_h - img_h
+    pad_h_from = diff_h // 2
+    pad_h_to = -(pad_h_from + diff_h % 2) or None
+    diff_w = target_w - img_w
+    pad_w_from = diff_w // 2
+    pad_w_to = -(pad_w_from + diff_w % 2) or None
+
+    result = torch.zeros(
+        (
+            1,
+            3,
+        )
+        + target_size,
+        device=image.device,
+        dtype=image.dtype,
+    )
+    result[:, :, pad_h_from:pad_h_to, pad_w_from:pad_w_to] = image
+    return result
+
+
+def main(
+    model_name: str,
+    input_dims: Tuple[int, int],
+    quantize: bool,
+    video_path: str,
+    subset_size: int,
+    backend: str,
+    device: str,
+):
+    """
+    Main function to load, quantize, and export a YOLO model.
+
+    :param model_name: The name of the YOLO model to load.
+    :param input_dims: Input spatial dimensions of the model as (height, width).
+    :param quantize: Whether to quantize the model.
+    :param video_path: Path to the video to use for the calibration.
+    :param subset_size: Number of calibration samples to take from the video.
+    :param backend: The ExecuTorch inference backend (e.g., "openvino", "xnnpack").
+    :param device: The device to run the model on (e.g., "cpu", "gpu").
+    """
+
+    # Load the selected model
+    model = YOLO(model_name)
+
+    if quantize:
+        if video_path is None:
+            raise RuntimeError(
+                "Could not quantize model without the video for the calibration."
+                " --video_path parameter is needed."
+            )
+        cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        print(f"Calibration video dims: h: {height} w: {width}")
+        calibration_dataset = CV2VideoDataset(cap)
+    else:
+        calibration_dataset = None
+
+    # Set up pre-processing: the dummy predict() call builds model.predictor,
+    # whose preprocess() is reused below as the transform function
+    np_dummy_tensor = np.ones((input_dims[0], input_dims[1], 3))
+    model.predict(np_dummy_tensor, imgsz=(input_dims[0], input_dims[1]), device="cpu")
+
+    pt_model = model.model.to(torch.device("cpu"))
+
+    def transform_fn(frame):
+        input_tensor = model.predictor.preprocess([frame])
+        return input_tensor
+
+    example_args = (transform_fn(np_dummy_tensor),)
+    with torch.no_grad():
+        aten_dialect = torch.export.export(pt_model, args=example_args)
+
+    if backend == "openvino":
+        lower_fn = lower_to_openvino
+    elif backend == "xnnpack":
+        lower_fn = lower_to_xnnpack
+
+    exec_prog = lower_fn(
+        aten_dialect=aten_dialect,
+        example_args=example_args,
+        transform_fn=transform_fn,
+        device=device,
+        calibration_dataset=calibration_dataset,
+        subset_size=subset_size,
+        quantize=quantize,
+    )
+
+    model_file_name = f"{model_name}_{'int8' if quantize else 'fp32'}_{backend}.pte"
+    with open(model_file_name, "wb") as file:
+        exec_prog.write_to_file(file)
+    print(f"Model exported and saved as {model_file_name} on {device}.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Export FP32 and INT8 Ultralytics YOLO models with ExecuTorch."
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="yolo12s",
+        help="Ultralytics YOLO model name.",
+    )
+    parser.add_argument(
+        "--input_dims",
+        type=eval,
+        default=[640, 640],
+        help="Input model dimensions in format [height, width] or (height, width). The default dimensions are [640, 640].",
+    )
+    parser.add_argument(
+        "--video_path",
+        type=str,
+        help="Path to the input video file to use for the quantization calibration.",
+    )
+    parser.add_argument(
+        "--quantize", action="store_true", help="Enable model quantization."
+    )
+    parser.add_argument(
+        "--subset_size",
+        type=int,
+        default=300,
+        help="Subset size for the quantized model calibration. The default value is 300.",
+    )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="openvino",
+        choices=["openvino", "xnnpack"],
+        help="Select the ExecuTorch inference backend (openvino, xnnpack). openvino by default.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="CPU",
+        help="Target device for compiling the model (e.g., CPU, GPU). Default is CPU.",
+    )
+
+    args = parser.parse_args()
+
+    # Run the main function with parsed arguments
+    # Disable nncf patching as export of the patched model is not supported.
+    with nncf.torch.disable_patching():
+        main(
+            model_name=args.model_name,
+            input_dims=args.input_dims,
+            quantize=args.quantize,
+            video_path=args.video_path,
+            subset_size=args.subset_size,
+            backend=args.backend,
+            device=args.device,
+        )
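
To smoke-test an exported model, a loader along these lines should work. This is a sketch under two assumptions: the executorch.runtime Python bindings (Runtime.get, load_program, load_method, execute) shipped with recent ExecuTorch releases are available in your install, and the .pte file name follows the script's {model_name}_{int8|fp32}_{backend}.pte pattern:

    # Sketch: load and run the exported program with ExecuTorch's Python runtime.
    # Assumes executorch.runtime is available; the file name below is what the
    # script writes for the default "yolo12s" model quantized on the openvino backend.
    import torch
    from executorch.runtime import Runtime

    runtime = Runtime.get()
    program = runtime.load_program("yolo12s_int8_openvino.pte")
    method = program.load_method("forward")
    # The input must match the export-time dims: (1, 3, height, width)
    outputs = method.execute([torch.ones(1, 3, 640, 640)])
    print([o.shape for o in outputs])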
