Skip to content

Commit 1c7749d

Browse files
authored
[Enhancement]: Support opset_version 13 (#2071)
* upgrade to opset 13 * fix unsqueeze * fix mmseg yml * fix mmseg reg test * forcely change opset13 * fix mmdet3d * optimize squeeze * update base dockerfile * support squeeze/unsqueeze with axes as input in onnx2ncnn * update optimizer for squeeze/unsqueeze * revert * Revert "support squeeze/unsqueeze with axes as input in onnx2ncnn" This reverts commit 5ca9f1a. * fix docs * fix opset
1 parent 389a146 commit 1c7749d

File tree

7 files changed

+69
-31
lines changed

7 files changed

+69
-31
lines changed

docker/Base/Dockerfile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,13 @@ RUN wget -c $TENSORRT_URL && \
108108
ENV TENSORRT_DIR=/root/workspace/TensorRT
109109
ENV LD_LIBRARY_PATH=$TENSORRT_DIR/lib:$LD_LIBRARY_PATH
110110
ENV PATH=$TENSORRT_DIR/bin:$PATH
111+
112+
# openvino
113+
RUN wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2022.3/linux/l_openvino_toolkit_ubuntu20_2022.3.0.9052.9752fafe8eb_x86_64.tgz &&\
114+
tar -zxvf ./l_openvino_toolkit*.tgz &&\
115+
rm ./l_openvino_toolkit*.tgz &&\
116+
mv ./l_openvino* ./openvino_toolkit &&\
117+
bash ./openvino_toolkit/install_dependencies/install_openvino_dependencies.sh
118+
119+
ENV OPENVINO_DIR=/root/workspace/openvino_toolkit
120+
ENV InferenceEngine_DIR=$OPENVINO_DIR/runtime/cmake

mmdeploy/apis/onnx/optimizer.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,38 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from typing import Callable
33

4+
import torch
5+
46
from mmdeploy.core import FUNCTION_REWRITER
57

68

9+
def update_squeeze_unsqueeze_opset13_pass(graph, params_dict, torch_out):
10+
"""Update Squeeze/Unsqueeze axes for opset13."""
11+
for node in graph.nodes():
12+
if node.kind() in ['onnx::Squeeze', 'onnx::Unsqueeze'] and \
13+
node.hasAttribute('axes'):
14+
axes = node['axes']
15+
axes_node = graph.create('onnx::Constant')
16+
axes_node.t_('value', torch.LongTensor(axes))
17+
node.removeAttribute('axes')
18+
node.addInput(axes_node.output())
19+
axes_node.insertBefore(node)
20+
return graph, params_dict, torch_out
21+
22+
723
@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph')
824
def model_to_graph__custom_optimizer(*args, **kwargs):
925
"""Rewriter of _model_to_graph, add custom passes."""
1026
ctx = FUNCTION_REWRITER.get_context()
1127
graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)
12-
28+
if hasattr(ctx, 'opset'):
29+
opset_version = ctx.opset
30+
else:
31+
from mmdeploy.utils import get_ir_config
32+
opset_version = get_ir_config(ctx.cfg).get('opset_version', 11)
33+
if opset_version >= 13:
34+
graph, params_dict, torch_out = update_squeeze_unsqueeze_opset13_pass(
35+
graph, params_dict, torch_out)
1336
custom_passes = getattr(ctx, 'onnx_custom_passes', None)
1437

1538
if custom_passes is not None:

mmdeploy/apis/onnx/passes/optimize_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ def optimize_onnx(ctx, graph, params_dict, torch_out):
1818
logger.warning(
1919
'Can not optimize model, please build torchscipt extension.\n'
2020
'More details: '
21-
'https://github.com/open-mmlab/mmdeploy/tree/1.x/docs/en/experimental/onnx_optimizer.md' # noqa
21+
'https://github.com/open-mmlab/mmdeploy/tree/main/docs/en/experimental/onnx_optimizer.md' # noqa
2222
)
2323
return graph, params_dict, torch_out

mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int,
4848
stride_w: int):
4949
"""Map ops to onnx symbolics."""
5050
# zero_h and zero_w is used to provide shape to GridPriorsTRT
51-
feat_h = g.op('Unsqueeze', feat_h, axes_i=[0])
52-
feat_w = g.op('Unsqueeze', feat_w, axes_i=[0])
51+
feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
52+
feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
5353
zero_h = g.op(
5454
'ConstantOfShape',
5555
feat_h,

mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ def build_backend_model(self,
9292

9393
def create_input(
9494
self,
95-
pcd: str,
95+
pcd: Union[str, Sequence[str]],
9696
input_shape: Sequence[int] = None,
9797
data_preprocessor: Optional[BaseDataPreprocessor] = None
9898
) -> Tuple[Dict, torch.Tensor]:
9999
"""Create input for detector.
100100
101101
Args:
102-
pcd (str): Input pcd file path.
102+
pcd (str, Sequence[str]): Input pcd file path.
103103
input_shape (Sequence[int], optional): model input shape.
104104
Defaults to None.
105105
data_preprocessor (Optional[BaseDataPreprocessor], optional):
@@ -115,7 +115,9 @@ def create_input(
115115
test_pipeline = Compose(test_pipeline)
116116
box_type_3d, box_mode_3d = \
117117
get_box_type(cfg.test_dataloader.dataset.box_type_3d)
118-
118+
# do not support batch inference
119+
if isinstance(pcd, (list, tuple)):
120+
pcd = pcd[0]
119121
data = []
120122
data_ = dict(
121123
lidar_points=dict(lidar_path=pcd),

tests/regression/mmseg.yml

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ torchscript:
122122

123123
models:
124124
- name: FCN
125-
metafile: configs/fcn/fcn.yml
125+
metafile: configs/fcn/metafile.yaml
126126
model_configs:
127127
- configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py
128128
pipelines:
@@ -134,7 +134,7 @@ models:
134134
- *pipeline_openvino_dynamic_fp32
135135

136136
- name: PSPNet
137-
metafile: configs/pspnet/pspnet.yml
137+
metafile: configs/pspnet/metafile.yaml
138138
model_configs:
139139
- configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
140140
pipelines:
@@ -146,7 +146,7 @@ models:
146146
- *pipeline_openvino_static_fp32
147147

148148
- name: deeplabv3
149-
metafile: configs/deeplabv3/deeplabv3.yml
149+
metafile: configs/deeplabv3/metafile.yaml
150150
model_configs:
151151
- configs/deeplabv3/deeplabv3_r50-d8_4xb2-40k_cityscapes-512x1024.py
152152
pipelines:
@@ -158,7 +158,7 @@ models:
158158
- *pipeline_openvino_dynamic_fp32
159159

160160
- name: deeplabv3+
161-
metafile: configs/deeplabv3plus/deeplabv3plus.yml
161+
metafile: configs/deeplabv3plus/metafile.yaml
162162
model_configs:
163163
- configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-40k_cityscapes-512x1024.py
164164
pipelines:
@@ -170,7 +170,7 @@ models:
170170
- *pipeline_openvino_dynamic_fp32
171171

172172
- name: Fast-SCNN
173-
metafile: configs/fastscnn/fastscnn.yml
173+
metafile: configs/fastscnn/metafile.yaml
174174
model_configs:
175175
- configs/fastscnn/fast_scnn_8xb4-160k_cityscapes-512x1024.py
176176
pipelines:
@@ -181,7 +181,7 @@ models:
181181
- *pipeline_openvino_static_fp32
182182

183183
- name: UNet
184-
metafile: configs/unet/unet.yml
184+
metafile: configs/unet/metafile.yaml
185185
model_configs:
186186
- configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py
187187
pipelines:
@@ -192,7 +192,7 @@ models:
192192
- *pipeline_pplnn_dynamic_fp32
193193

194194
- name: ANN
195-
metafile: configs/ann/ann.yml
195+
metafile: configs/ann/metafile.yaml
196196
model_configs:
197197
- configs/ann/ann_r50-d8_4xb2-40k_cityscapes-512x1024.py
198198
pipelines:
@@ -201,7 +201,7 @@ models:
201201
- *pipeline_ts_fp32
202202

203203
- name: APCNet
204-
metafile: configs/apcnet/apcnet.yml
204+
metafile: configs/apcnet/metafile.yaml
205205
model_configs:
206206
- configs/apcnet/apcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
207207
pipelines:
@@ -211,7 +211,7 @@ models:
211211
- *pipeline_ts_fp32
212212

213213
- name: BiSeNetV1
214-
metafile: configs/bisenetv1/bisenetv1.yml
214+
metafile: configs/bisenetv1/metafile.yaml
215215
model_configs:
216216
- configs/bisenetv1/bisenetv1_r18-d32_4xb4-160k_cityscapes-1024x1024.py
217217
pipelines:
@@ -222,7 +222,7 @@ models:
222222
- *pipeline_ts_fp32
223223

224224
- name: BiSeNetV2
225-
metafile: configs/bisenetv2/bisenetv2.yml
225+
metafile: configs/bisenetv2/metafile.yaml
226226
model_configs:
227227
- configs/bisenetv2/bisenetv2_fcn_4xb4-160k_cityscapes-1024x1024.py
228228
pipelines:
@@ -233,7 +233,7 @@ models:
233233
- *pipeline_ts_fp32
234234

235235
- name: CGNet
236-
metafile: configs/cgnet/cgnet.yml
236+
metafile: configs/cgnet/metafile.yaml
237237
model_configs:
238238
- configs/cgnet/cgnet_fcn_4xb8-60k_cityscapes-512x1024.py
239239
pipelines:
@@ -244,7 +244,7 @@ models:
244244
- *pipeline_ts_fp32
245245

246246
- name: EMANet
247-
metafile: configs/emanet/emanet.yml
247+
metafile: configs/emanet/metafile.yaml
248248
model_configs:
249249
- configs/emanet/emanet_r50-d8_4xb2-80k_cityscapes-512x1024.py
250250
pipelines:
@@ -254,7 +254,7 @@ models:
254254
- *pipeline_ts_fp32
255255

256256
- name: EncNet
257-
metafile: configs/encnet/encnet.yml
257+
metafile: configs/encnet/metafile.yaml
258258
model_configs:
259259
- configs/encnet/encnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
260260
pipelines:
@@ -264,7 +264,7 @@ models:
264264
- *pipeline_ts_fp32
265265

266266
- name: ERFNet
267-
metafile: configs/erfnet/erfnet.yml
267+
metafile: configs/erfnet/metafile.yaml
268268
model_configs:
269269
- configs/erfnet/erfnet_fcn_4xb4-160k_cityscapes-512x1024.py
270270
pipelines:
@@ -275,7 +275,7 @@ models:
275275
- *pipeline_ts_fp32
276276

277277
- name: FastFCN
278-
metafile: configs/fastfcn/fastfcn.yml
278+
metafile: configs/fastfcn/metafile.yaml
279279
model_configs:
280280
- configs/fastfcn/fastfcn_r50-d32_jpu_aspp_4xb2-80k_cityscapes-512x1024.py
281281
pipelines:
@@ -286,7 +286,7 @@ models:
286286
- *pipeline_ts_fp32
287287

288288
- name: GCNet
289-
metafile: configs/gcnet/gcnet.yml
289+
metafile: configs/gcnet/metafile.yaml
290290
model_configs:
291291
- configs/gcnet/gcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
292292
pipelines:
@@ -295,7 +295,7 @@ models:
295295
- *pipeline_ts_fp32
296296

297297
- name: ICNet
298-
metafile: configs/icnet/icnet.yml
298+
metafile: configs/icnet/metafile.yaml
299299
model_configs:
300300
- configs/icnet/icnet_r18-d8_4xb2-80k_cityscapes-832x832.py
301301
pipelines:
@@ -305,7 +305,7 @@ models:
305305
- *pipeline_ts_fp32
306306

307307
- name: ISANet
308-
metafile: configs/isanet/isanet.yml
308+
metafile: configs/isanet/metafile.yaml
309309
model_configs:
310310
- configs/isanet/isanet_r50-d8_4xb2-40k_cityscapes-512x1024.py
311311
pipelines:
@@ -314,7 +314,7 @@ models:
314314
- *pipeline_openvino_static_fp32_512x512
315315

316316
- name: OCRNet
317-
metafile: configs/ocrnet/ocrnet.yml
317+
metafile: configs/ocrnet/metafile.yaml
318318
model_configs:
319319
- configs/ocrnet/ocrnet_hr18s_4xb2-40k_cityscapes-512x1024.py
320320
pipelines:
@@ -325,7 +325,7 @@ models:
325325
- *pipeline_ts_fp32
326326

327327
- name: PointRend
328-
metafile: configs/point_rend/point_rend.yml
328+
metafile: configs/point_rend/metafile.yaml
329329
model_configs:
330330
- configs/point_rend/pointrend_r50_4xb2-80k_cityscapes-512x1024.py
331331
pipelines:
@@ -334,7 +334,7 @@ models:
334334
- *pipeline_ts_fp32
335335

336336
- name: Semantic FPN
337-
metafile: configs/sem_fpn/sem_fpn.yml
337+
metafile: configs/sem_fpn/metafile.yaml
338338
model_configs:
339339
- configs/sem_fpn/fpn_r50_4xb2-80k_cityscapes-512x1024.py
340340
pipelines:
@@ -345,7 +345,7 @@ models:
345345
- *pipeline_ts_fp32
346346

347347
- name: STDC
348-
metafile: configs/stdc/stdc.yml
348+
metafile: configs/stdc/metafile.yaml
349349
model_configs:
350350
- configs/stdc/stdc1_in1k-pre_4xb12-80k_cityscapes-512x1024.py
351351
- configs/stdc/stdc2_in1k-pre_4xb12-80k_cityscapes-512x1024.py
@@ -357,14 +357,14 @@ models:
357357
- *pipeline_ts_fp32
358358

359359
- name: UPerNet
360-
metafile: configs/upernet/upernet.yml
360+
metafile: configs/upernet/metafile.yaml
361361
model_configs:
362362
- configs/upernet/upernet_r50_4xb2-40k_cityscapes-512x1024.py
363363
pipelines:
364364
- *pipeline_ort_static_fp32
365365
- *pipeline_trt_static_fp16
366366
- name: Segmenter
367-
metafile: configs/segmenter/segmenter.yml
367+
metafile: configs/segmenter/metafile.yaml
368368
model_configs:
369369
- configs/segmenter/segmenter_vit-s_fcn_8xb1-160k_ade20k-512x512.py
370370
pipelines:

tools/regression_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
302302
# get metric
303303
model_info = meta_info[model_config_name]
304304
metafile_metric_info = model_info['Results']
305+
# deal with mmseg case
306+
if not isinstance(metafile_metric_info, (list, tuple)):
307+
metafile_metric_info = [metafile_metric_info]
305308
pytorch_metric = dict()
306309
using_dataset = set()
307310
using_task = set()

0 commit comments

Comments
 (0)