
Commit 27b3f47

[cherry-pick] post cut commits for release (#113)
1 parent e969e92 commit 27b3f47

12 files changed: +479, -71 lines changed


integrations/ultralytics/README.md

Lines changed: 6 additions & 3 deletions
@@ -119,9 +119,12 @@ cp sparseml/integrations/ultralytics/deepsparse/*.py yolov5
 cd yolov5
 
 # install deepsparse and server dependencies
-pip install deepsparse flask flask-cors
+pip install deepsparse sparseml flask flask-cors
 ```
 
+Note: on new Ubuntu systems, running `sudo apt-get update && apt-get install -y python3-opencv`
+may be necessary to install `cv2`.
+
 
 ### Benchmarking
 `benchmarking.py` is a script for benchmarking sparsified and quantized YOLOv3
@@ -130,8 +133,8 @@ performance with DeepSparse. For a full list of options run `python benchmarkin
 To run a benchmark run:
 ```bash
 python benchmark.py \
-    zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_90 \
-    --batch-size 32 \
+    zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94 \
+    --batch-size 1 \
     --quantized-inputs
 ```

integrations/ultralytics/deepsparse/SERVER.md

Lines changed: 5 additions & 2 deletions
@@ -49,11 +49,14 @@ pip install deepsparse sparseml flask flask-cors
 
 ### Server
 
-First, start up the host `server.py` with your model of choice.
+First, start up the host `server.py` with your model of choice; SparseZoo stubs are
+also supported.
 
 Example command:
 ```bash
-python server.py ~/models/yolov3-pruned_quant.onnx
+python server.py \
+    zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94 \
+    --quantized-inputs
 ```
 
 You can leave that running as a detached process or in a spare terminal.
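For reference, the stub in the command above is resolved to a local ONNX file through the SparseZoo API. The sketch below mirrors the `Zoo.load_model_from_stub` usage added to `deepsparse_utils.py` later in this commit; the stub value is simply the one from the example command.

```python
# Minimal sketch: resolve a SparseZoo stub to a local ONNX file path.
# Mirrors the calls added in deepsparse_utils.py below; stub value is illustrative.
from sparsezoo import Zoo

stub = "zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94"
model = Zoo.load_model_from_stub(stub)
onnx_path = model.onnx_file.downloaded_path()  # downloads the file if not already present
print(f"{stub} -> {onnx_path}")
```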

integrations/ultralytics/deepsparse/benchmark.py

Lines changed: 60 additions & 17 deletions
@@ -21,8 +21,9 @@
 usage: benchmark.py [-h] [-e {deepsparse,onnxruntime,torch}]
                     [--data-path DATA_PATH]
                     [--image-shape IMAGE_SHAPE [IMAGE_SHAPE ...]]
-                    [-b BATCH_SIZE] [-c NUM_CORES] [-i NUM_ITERATIONS]
-                    [-w NUM_WARMUP_ITERATIONS] [-q] [--fp16] [--device DEVICE]
+                    [-b BATCH_SIZE] [-c NUM_CORES] [-s NUM_SOCKETS]
+                    [-i NUM_ITERATIONS] [-w NUM_WARMUP_ITERATIONS] [-q]
+                    [--fp16] [--device DEVICE]
                     model_filepath
 
 Benchmark sparsified YOLOv3 models
@@ -53,7 +54,13 @@
                         The batch size to run the benchmark for
   -c NUM_CORES, --num-cores NUM_CORES
                         The number of physical cores to run the benchmark on,
-                        defaults to all physical cores available on the system
+                        defaults to None where it uses all physical cores
+                        available on the system. For DeepSparse benchmarks,
+                        this value is the number of cores per socket
+  -s NUM_SOCKETS, --num-sockets NUM_SOCKETS
+                        For DeepSparse benchmarks only. The number of physical
+                        sockets to run the benchmark on. Defaults to None where
+                        it uses all sockets available on the system
   -i NUM_ITERATIONS, --num-iterations NUM_ITERATIONS
                         The number of iterations the benchmark will be run for
   -w NUM_WARMUP_ITERATIONS, --num-warmup-iterations NUM_WARMUP_ITERATIONS
@@ -110,14 +117,17 @@
 
 from deepsparse import compile_model, cpu
 from deepsparse.benchmark import BenchmarkResults
-from deepsparse_utils import load_image, postprocess_nms, pre_nms_postprocess
+from deepsparse_utils import (
+    YoloPostprocessor,
+    load_image,
+    modify_yolo_onnx_input_shape,
+    postprocess_nms,
+)
 from sparseml.onnx.utils import override_model_batch_size
 from sparsezoo.models.detection import yolo_v3 as zoo_yolo_v3
 from sparsezoo.utils import load_numpy_list
 
 
-CORES_PER_SOCKET, AVX_TYPE, _ = cpu.cpu_details()
-
 DEEPSPARSE_ENGINE = "deepsparse"
 ORT_ENGINE = "onnxruntime"
 TORCH_ENGINE = "torch"
@@ -180,10 +190,22 @@ def parse_args():
         "-c",
         "--num-cores",
         type=int,
-        default=CORES_PER_SOCKET,
+        default=None,
         help=(
             "The number of physical cores to run the benchmark on, "
-            "defaults to all physical cores available on the system"
+            "defaults to None where it uses all physical cores available on the system. "
+            "For DeepSparse benchmarks, this value is the number of cores per socket"
+        ),
+    )
+    parser.add_argument(
+        "-s",
+        "--num-sockets",
+        type=int,
+        default=None,
+        help=(
+            "For DeepSparse benchmarks only. The number of physical sockets to run the "
+            "benchmark on. Defaults to None where it uses all sockets available on the "
+            "system"
         ),
     )
     parser.add_argument(
@@ -227,7 +249,6 @@ def parse_args():
     )
 
     args = parser.parse_args()
-
    if args.engine == TORCH_ENGINE and args.device is None:
        args.device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
@@ -273,30 +294,46 @@ def _load_model(args) -> Any:
         raise ValueError(f"half precision is not supported for {args.engine}")
     if args.quantized_inputs and args.engine == TORCH_ENGINE:
         raise ValueError(f"quantized inputs not supported for {args.engine}")
-    if args.num_cores != CORES_PER_SOCKET and args.engine == TORCH_ENGINE:
+    if args.num_cores is not None and args.engine == TORCH_ENGINE:
         raise ValueError(
             f"overriding default num_cores not supported for {args.engine}"
         )
     if (
-        args.num_cores != CORES_PER_SOCKET
+        args.num_cores is not None
         and args.engine == ORT_ENGINE
         and onnxruntime.__version__ < "1.7"
     ):
-        print(
+        raise ValueError(
             "overriding default num_cores not supported for onnxruntime < 1.7.0. "
             "If using an older build with OpenMP, try setting the OMP_NUM_THREADS "
             "environment variable"
         )
+    if args.num_sockets is not None and args.engine != DEEPSPARSE_ENGINE:
+        raise ValueError(f"Overriding num_sockets is not supported for {args.engine}")
+
+    # scale static ONNX graph to desired image shape
+    if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]:
+        args.model_filepath, _ = modify_yolo_onnx_input_shape(
+            args.model_filepath, args.image_shape
+        )
 
     # load model
     if args.engine == DEEPSPARSE_ENGINE:
         print(f"Compiling deepsparse model for {args.model_filepath}")
-        model = compile_model(args.model_filepath, args.batch_size, args.num_cores)
+        model = compile_model(
+            args.model_filepath, args.batch_size, args.num_cores, args.num_sockets
+        )
+        if args.quantized_inputs and not model.cpu_vnni:
+            print(
+                "WARNING: VNNI instructions not detected, "
+                "quantization speedup not well supported"
+            )
     elif args.engine == ORT_ENGINE:
         print(f"loading onnxruntime model for {args.model_filepath}")
 
         sess_options = onnxruntime.SessionOptions()
-        sess_options.intra_op_num_threads = args.num_cores
+        if args.num_cores is not None:
+            sess_options.intra_op_num_threads = args.num_cores
         sess_options.log_severity_level = 3
         sess_options.graph_optimization_level = (
             onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -376,7 +413,7 @@ def _run_model(
 def benchmark_yolo(args):
     model = _load_model(args)
     print("Loading dataset")
-    dataset = _load_images(args.data_path, args.image_shape)
+    dataset = _load_images(args.data_path, tuple(args.image_shape))
     total_iterations = args.num_iterations + args.num_warmup_iterations
     data_loader = _iter_batches(dataset, args.batch_size, total_iterations)
 
@@ -388,6 +425,12 @@ def benchmark_yolo(args):
         flush=True,
     )
 
+    postprocessor = (
+        YoloPostprocessor(args.image_shape)
+        if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]
+        else None
+    )
+
     results = BenchmarkResults()
     progress_bar = tqdm(total=args.num_iterations)
 
@@ -403,8 +446,8 @@ def benchmark_yolo(args):
         outputs = _run_model(args, model, batch)
 
         # post-processing
-        if args.engine != TORCH_ENGINE:
-            outputs = pre_nms_postprocess(outputs)
+        if postprocessor:
+            outputs = postprocessor.pre_nms_postprocess(outputs)
 
         # NMS
         outputs = postprocess_nms(outputs)
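
To make the new post-processing path concrete, the following sketch (not part of the commit) feeds dummy raw YOLOv3 outputs through `YoloPostprocessor` and `postprocess_nms` from `deepsparse_utils.py` (diffed below). It assumes `deepsparse_utils.py` and the ultralytics/yolov5 helpers are importable as the README describes; the 640x640 input size and per-head grid sizes are illustrative.

```python
# Sketch of the benchmark's decode + NMS path on dummy data; assumes deepsparse_utils.py
# and the ultralytics/yolov5 repo files are on the Python path as described in the README.
import numpy

from deepsparse_utils import YoloPostprocessor, postprocess_nms

image_shape = (640, 640)  # illustrative model input size
batch_size = 1

# raw (pre anchor-grid) YOLOv3 outputs: one tensor per detection head
raw_outputs = [
    numpy.random.rand(batch_size, 3, grid, grid, 85).astype(numpy.float32)
    for grid in (80, 40, 20)
]

# engines other than torch emit raw grid outputs, so decode them first ...
postprocessor = YoloPostprocessor(image_shape)
decoded = postprocessor.pre_nms_postprocess(raw_outputs)  # (batch, num_predictions, 85)

# ... then apply NMS (delegates to yolov5's non_max_suppression)
detections = postprocess_nms(decoded)
print(decoded.shape, [d.shape for d in detections])
```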

integrations/ultralytics/deepsparse/deepsparse_utils.py

Lines changed: 107 additions & 29 deletions
@@ -20,12 +20,17 @@
 """
 
 
-from typing import List, Tuple, Union
+from tempfile import NamedTemporaryFile
+from typing import List, Optional, Tuple, Union
 
 import cv2
 import numpy
+import onnx
 import torch
 
+from sparseml.onnx.utils import get_tensor_dim_shape, set_tensor_dim_shape
+from sparsezoo import Zoo
+
 # ultralytics/yolov5 imports
 from utils.general import non_max_suppression
 
@@ -37,22 +42,13 @@
 ]
 
 
-def _get_grid(size: int) -> torch.Tensor:
-    # adapted from yolov5.yolo.Detect._make_grid
-    coords_y, coords_x = torch.meshgrid([torch.arange(size), torch.arange(size)])
-    grid = torch.stack((coords_x, coords_y), 2)
-    return grid.view(1, 1, size, size, 2).float()
-
-
 # Yolo V3 specific variables
 _YOLO_V3_ANCHORS = [
     torch.Tensor([[10, 13], [16, 30], [33, 23]]),
     torch.Tensor([[30, 61], [62, 45], [59, 119]]),
     torch.Tensor([[116, 90], [156, 198], [373, 326]]),
 ]
 _YOLO_V3_ANCHOR_GRIDS = [t.clone().view(1, -1, 1, 1, 2) for t in _YOLO_V3_ANCHORS]
-_YOLO_V3_OUTPUT_SHAPES = [80, 40, 20]
-_YOLO_V3_GRIDS = [_get_grid(grid_size) for grid_size in _YOLO_V3_OUTPUT_SHAPES]
 
 
 def load_image(
@@ -70,28 +66,53 @@ def load_image(
     return img
 
 
-def pre_nms_postprocess(outputs: List[numpy.ndarray]) -> torch.Tensor:
-    """
-    :param outputs: raw outputs of a YOLOv3 model before anchor grid processing
-    :return: post-processed model outputs without NMS.
-    """
-    # postprocess and transform raw outputs into single torch tensor
-    processed_outputs = []
-    for idx, pred in enumerate(outputs):
-        pred = torch.from_numpy(pred)
-        pred = pred.sigmoid()
-
-        # get grid and stride
-        grid = _YOLO_V3_GRIDS[idx]
-        anchor_grid = _YOLO_V3_ANCHOR_GRIDS[idx]
-        stride = 640 / _YOLO_V3_OUTPUT_SHAPES[idx]
-
-        # decode xywh box values
-        pred[..., 0:2] = (pred[..., 0:2] * 2.0 - 0.5 + grid) * stride
-        pred[..., 2:4] = (pred[..., 2:4] * 2) ** 2 * anchor_grid
-        # flatten anchor and grid dimensions -> (bs, num_predictions, num_classes + 5)
-        processed_outputs.append(pred.view(pred.size(0), -1, pred.size(-1)))
-    return torch.cat(processed_outputs, 1)
+class YoloPostprocessor:
+    """
+    Class for performing postprocessing of YOLOv3 model predictions
+
+    :param image_size: size of input image to model. used to calculate stride based on
+        output shapes
+    """
+
+    def __init__(self, image_size: Tuple[int]):
+        self._image_size = image_size
+        self._grids = {}  # Dict[Tuple[int], torch.Tensor]
+
+    def pre_nms_postprocess(self, outputs: List[numpy.ndarray]) -> torch.Tensor:
+        """
+        :param outputs: raw outputs of a YOLOv3 model before anchor grid processing
+        :return: post-processed model outputs without NMS.
+        """
+        # postprocess and transform raw outputs into single torch tensor
+        processed_outputs = []
+        for idx, pred in enumerate(outputs):
+            pred = torch.from_numpy(pred)
+            pred = pred.sigmoid()
+
+            # get grid and stride
+            grid_shape = pred.shape[2:4]
+            grid = self._get_grid(grid_shape)
+            anchor_grid = _YOLO_V3_ANCHOR_GRIDS[idx]
+            stride = self._image_size[0] / grid_shape[0]
+
+            # decode xywh box values
+            pred[..., 0:2] = (pred[..., 0:2] * 2.0 - 0.5 + grid) * stride
+            pred[..., 2:4] = (pred[..., 2:4] * 2) ** 2 * anchor_grid
+            # flatten anchor and grid dimensions -> (bs, num_predictions, num_classes + 5)
+            processed_outputs.append(pred.view(pred.size(0), -1, pred.size(-1)))
+        return torch.cat(processed_outputs, 1)
+
+    def _get_grid(self, grid_shape: Tuple[int]) -> torch.Tensor:
+        if grid_shape not in self._grids:
+            # adapted from yolov5.yolo.Detect._make_grid
+            coords_y, coords_x = torch.meshgrid(
+                [torch.arange(grid_shape[0]), torch.arange(grid_shape[1])]
+            )
+            grid = torch.stack((coords_x, coords_y), 2)
+            self._grids[grid_shape] = grid.view(
+                1, 1, grid_shape[0], grid_shape[1], 2
+            ).float()
+        return self._grids[grid_shape]
 
 
 def postprocess_nms(outputs: torch.Tensor) -> List[numpy.ndarray]:
@@ -102,3 +123,60 @@ def postprocess_nms(outputs: torch.Tensor) -> List[numpy.ndarray]:
     # run nms in PyTorch, only post-process first output
     nms_outputs = non_max_suppression(outputs)
     return [output.cpu().numpy() for output in nms_outputs]
+
+
+def modify_yolo_onnx_input_shape(
+    model_path: str, image_shape: Tuple[int]
+) -> Tuple[str, Optional[NamedTemporaryFile]]:
+    """
+    Creates a new YOLOv3 ONNX model from the given path that accepts the given input
+    shape. If the given model already has the given input shape no modifications are
+    made. Uses a tempfile to store the modified model file.
+
+    :param model_path: file path to YOLOv3 ONNX model or SparseZoo stub of the model
+        to be loaded
+    :param image_shape: 2-tuple of the image shape to resize this yolo model to
+    :return: filepath to an onnx model reshaped to the given input shape will be the
+        original path if the shape is the same. Additionally returns the
+        NamedTemporaryFile for managing the scope of the object for file deletion
+    """
+    original_model_path = model_path
+    if model_path.startswith("zoo:"):
+        # load SparseZoo Model from stub
+        model = Zoo.load_model_from_stub(model_path)
+        model_path = model.onnx_file.downloaded_path()
+        print(f"Downloaded {original_model_path} to {model_path}")
+
+    model = onnx.load(model_path)
+    model_input = model.graph.input[0]
+
+    initial_x = get_tensor_dim_shape(model_input, 2)
+    initial_y = get_tensor_dim_shape(model_input, 3)
+
+    if not (isinstance(initial_x, int) and isinstance(initial_y, int)):
+        return model_path, None  # model graph does not have static integer input shape
+
+    if (initial_x, initial_y) == tuple(image_shape):
+        return model_path, None  # no shape modification needed
+
+    scale_x = initial_x / image_shape[0]
+    scale_y = initial_y / image_shape[1]
+    set_tensor_dim_shape(model_input, 2, image_shape[0])
+    set_tensor_dim_shape(model_input, 3, image_shape[1])
+
+    for model_output in model.graph.output:
+        output_x = get_tensor_dim_shape(model_output, 2)
+        output_y = get_tensor_dim_shape(model_output, 3)
+        set_tensor_dim_shape(model_output, 2, int(output_x / scale_x))
+        set_tensor_dim_shape(model_output, 3, int(output_y / scale_y))
+
+    tmp_file = NamedTemporaryFile()  # file will be deleted after program exit
+    onnx.save(model, tmp_file.name)
+
+    print(
+        f"Overwriting original model shape {(initial_x, initial_y)} to {image_shape}\n"
+        f"Original model path: {original_model_path}, new temporary model saved to "
+        f"{tmp_file.name}"
+    )
+
+    return tmp_file.name, tmp_file
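
As a usage note, here is a minimal sketch of the new helper paired with DeepSparse, mirroring the updated `benchmark.py` call. It is not part of the commit: the stub, the 416x416 shape, and the core/socket values are illustrative, and running it requires `deepsparse` plus network access for the stub download.

```python
# Sketch only: rescale a YOLOv3 ONNX graph and compile it with DeepSparse, mirroring
# the updated benchmark.py call; stub, shape, and core/socket values are illustrative.
import numpy

from deepsparse import compile_model
from deepsparse_utils import modify_yolo_onnx_input_shape

stub = "zoo:cv/detection/yolo_v3-spp/pytorch/ultralytics/coco/pruned_quant-aggressive_94"
image_shape = (416, 416)  # rescale the static 640x640 graph down to 416x416
batch_size = 1

# returns a temp-file path when the graph was rescaled; keep tmp_file alive while in use
model_path, tmp_file = modify_yolo_onnx_input_shape(stub, image_shape)

# positional arguments match benchmark.py: model, batch_size, num_cores, num_sockets
# (num_cores=None uses all cores per socket; num_sockets=1 pins to a single socket)
engine = compile_model(model_path, batch_size, None, 1)

sample = [numpy.random.rand(batch_size, 3, *image_shape).astype(numpy.float32)]
outputs = engine.run(sample)  # three raw detection-head outputs
print([out.shape for out in outputs])
```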
