Skip to content

Commit db39fd4

Browse files
[Enhance] Use Inferencer to implement Demo (#2763)
1 parent f4c032e commit db39fd4

22 files changed

+927
-533
lines changed

demo/inference_demo.ipynb

Lines changed: 48 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -2,117 +2,83 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": null,
6-
"source": [
7-
"from mmdet3d.apis import inference_detector, init_model\n",
8-
"from mmdet3d.registry import VISUALIZERS\n",
9-
"from mmdet3d.utils import register_all_modules"
10-
],
11-
"outputs": [],
5+
"execution_count": 25,
126
"metadata": {
137
"pycharm": {
148
"is_executing": false
159
}
16-
}
17-
},
18-
{
19-
"cell_type": "code",
20-
"execution_count": null,
21-
"source": [
22-
"# register all modules in mmdet3d into the registries\n",
23-
"register_all_modules()"
24-
],
10+
},
2511
"outputs": [],
26-
"metadata": {}
27-
},
28-
{
29-
"cell_type": "code",
30-
"execution_count": 8,
3112
"source": [
32-
"config_file = '../configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py'\n",
33-
"# download the checkpoint from model zoo and put it in `checkpoints/`\n",
34-
"checkpoint_file = '../work_dirs/second/epoch_40.pth'"
35-
],
36-
"outputs": [],
37-
"metadata": {
38-
"pycharm": {
39-
"is_executing": false
40-
}
41-
}
13+
"from mmdet3d.apis import LidarDet3DInferencer"
14+
]
4215
},
4316
{
4417
"cell_type": "code",
4518
"execution_count": null,
46-
"source": [
47-
"# build the model from a config file and a checkpoint file\n",
48-
"model = init_model(config_file, checkpoint_file, device='cuda:0')"
49-
],
19+
"metadata": {},
5020
"outputs": [],
51-
"metadata": {}
21+
"source": [
22+
"# initialize inferencer\n",
23+
"inferencer = LidarDet3DInferencer('pointpillars_kitti-3class')"
24+
]
5225
},
5326
{
5427
"cell_type": "code",
5528
"execution_count": null,
56-
"source": [
57-
"# init visualizer\n",
58-
"visualizer = VISUALIZERS.build(model.cfg.visualizer)\n",
59-
"visualizer.dataset_meta = {\n",
60-
" 'CLASSES': model.CLASSES,\n",
61-
" 'PALETTE': model.PALETTE\n",
62-
"}"
63-
],
64-
"outputs": [],
6529
"metadata": {
6630
"pycharm": {
6731
"is_executing": false
6832
}
69-
}
33+
},
34+
"outputs": [],
35+
"source": [
36+
"# inference\n",
37+
"inputs = dict(points='./data/kitti/000008.bin')\n",
38+
"inferencer(inputs)"
39+
]
7040
},
7141
{
7242
"cell_type": "code",
73-
"execution_count": 11,
74-
"source": [
75-
"# test a single sample\n",
76-
"pcd = './data/kitti/000008.bin'\n",
77-
"result, data = inference_detector(model, pcd)\n",
78-
"points = data['inputs']['points']\n",
79-
"data_input = dict(points=points)"
80-
],
43+
"execution_count": null,
44+
"metadata": {},
8145
"outputs": [],
82-
"metadata": {
83-
"pycharm": {
84-
"is_executing": false
85-
}
86-
}
46+
"source": [
47+
"# inference and visualize\n",
48+
"# NOTE: use the `Esc` key to exit Open3D window in Jupyter Notebook Environment\n",
49+
"inferencer(inputs, show=True)"
50+
]
8751
},
8852
{
8953
"cell_type": "code",
9054
"execution_count": null,
91-
"source": [
92-
"# show the results\n",
93-
"out_dir = './'\n",
94-
"visualizer.add_datasample(\n",
95-
" 'result',\n",
96-
" data_input,\n",
97-
" data_sample=result,\n",
98-
" draw_gt=False,\n",
99-
" show=True,\n",
100-
" wait_time=0,\n",
101-
" out_file=out_dir,\n",
102-
" vis_task='det')"
103-
],
55+
"metadata": {},
10456
"outputs": [],
105-
"metadata": {
106-
"pycharm": {
107-
"is_executing": false
108-
}
109-
}
57+
"source": [
58+
"# If your operating environment does not have a display device,\n",
59+
"# (e.g. a remote server), you can save the predictions and visualize\n",
60+
"# them in local devices.\n",
61+
"inferencer(inputs, show=False, out_dir='./remote_outputs')\n",
62+
"\n",
63+
"# Simulate the migration process\n",
64+
"%mv ./remote_outputs ./local_outputs\n",
65+
"\n",
66+
"# Visualize the predictions from the saved files\n",
67+
"# NOTE: use the `Esc` key to exit Open3D window in Jupyter Notebook Environment\n",
68+
"local_inferencer = LidarDet3DInferencer('pointpillars_kitti-3class')\n",
69+
"inputs = local_inferencer._inputs_to_list(inputs)\n",
70+
"local_inferencer.visualize_preds_fromfile(inputs, ['local_outputs/preds/000008.json'], show=True)"
71+
]
11072
}
11173
],
11274
"metadata": {
75+
"interpreter": {
76+
"hash": "a0c343fece975dd89087e8c2194dd4d3db28d7000f1b32ed9ed9d584dd54dbbe"
77+
},
11378
"kernelspec": {
114-
"name": "python3",
115-
"display_name": "Python 3.7.6 64-bit ('torch1.7-cu10.1': conda)"
79+
"display_name": "Python 3 (ipykernel)",
80+
"language": "python",
81+
"name": "python3"
11682
},
11783
"language_info": {
11884
"codemirror_mode": {
@@ -124,19 +90,16 @@
12490
"name": "python",
12591
"nbconvert_exporter": "python",
12692
"pygments_lexer": "ipython3",
127-
"version": "3.7.6"
93+
"version": "3.9.16"
12894
},
12995
"pycharm": {
13096
"stem_cell": {
13197
"cell_type": "raw",
132-
"source": [],
13398
"metadata": {
13499
"collapsed": false
135-
}
100+
},
101+
"source": []
136102
}
137-
},
138-
"interpreter": {
139-
"hash": "a0c343fece975dd89087e8c2194dd4d3db28d7000f1b32ed9ed9d584dd54dbbe"
140103
}
141104
},
142105
"nbformat": 4,

demo/mono_det_demo.py

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import logging
3+
import os
24
from argparse import ArgumentParser
35

4-
import mmcv
6+
from mmengine.logging import print_log
57

6-
from mmdet3d.apis import inference_mono_3d_detector, init_model
7-
from mmdet3d.registry import VISUALIZERS
8+
from mmdet3d.apis import MonoDet3DInferencer
89

910

1011
def parse_args():
1112
parser = ArgumentParser()
12-
parser.add_argument('img', help='image file')
13-
parser.add_argument('ann', help='ann file')
14-
parser.add_argument('config', help='Config file')
15-
parser.add_argument('checkpoint', help='Checkpoint file')
13+
parser.add_argument('img', help='Image file')
14+
parser.add_argument('infos', help='Infos file with annotations')
15+
parser.add_argument('model', help='Config file')
16+
parser.add_argument('weights', help='Checkpoint file')
1617
parser.add_argument(
1718
'--device', default='cuda:0', help='Device used for inference')
1819
parser.add_argument(
@@ -21,50 +22,77 @@ def parse_args():
2122
default='CAM_BACK',
2223
help='choose camera type to inference')
2324
parser.add_argument(
24-
'--score-thr', type=float, default=0.30, help='bbox score threshold')
25+
'--pred-score-thr',
26+
type=float,
27+
default=0.3,
28+
help='bbox score threshold')
2529
parser.add_argument(
26-
'--out-dir', type=str, default='demo', help='dir to save results')
30+
'--out-dir',
31+
type=str,
32+
default='outputs',
33+
help='Output directory of prediction and visualization results.')
2734
parser.add_argument(
2835
'--show',
2936
action='store_true',
30-
help='show online visualization results')
37+
help='Show online visualization results')
38+
parser.add_argument(
39+
'--wait-time',
40+
type=float,
41+
default=-1,
42+
help='The interval of show (s). Demo will be blocked in showing'
43+
'results, if wait_time is -1. Defaults to -1.')
3144
parser.add_argument(
32-
'--snapshot',
45+
'--no-save-vis',
3346
action='store_true',
34-
help='whether to save online visualization results')
35-
args = parser.parse_args()
36-
return args
47+
help='Do not save detection visualization results')
48+
parser.add_argument(
49+
'--no-save-pred',
50+
action='store_true',
51+
help='Do not save detection prediction results')
52+
parser.add_argument(
53+
'--print-result',
54+
action='store_true',
55+
help='Whether to print the results.')
56+
call_args = vars(parser.parse_args())
57+
58+
call_args['inputs'] = dict(
59+
img=call_args.pop('img'), infos=call_args.pop('infos'))
60+
call_args.pop('cam_type')
61+
62+
if call_args['no_save_vis'] and call_args['no_save_pred']:
63+
call_args['out_dir'] = ''
64+
65+
init_kws = ['model', 'weights', 'device']
66+
init_args = {}
67+
for init_kw in init_kws:
68+
init_args[init_kw] = call_args.pop(init_kw)
3769

70+
# NOTE: If your operating environment does not have a display device,
71+
# (e.g. a remote server), you can save the predictions and visualize
72+
# them in local devices.
73+
if os.environ.get('DISPLAY') is None and call_args['show']:
74+
print_log(
75+
'Display device not found. `--show` is forced to False',
76+
logger='current',
77+
level=logging.WARNING)
78+
call_args['show'] = False
3879

39-
def main(args):
40-
# build the model from a config file and a checkpoint file
41-
model = init_model(args.config, args.checkpoint, device=args.device)
80+
return init_args, call_args
4281

43-
# init visualizer
44-
visualizer = VISUALIZERS.build(model.cfg.visualizer)
45-
visualizer.dataset_meta = model.dataset_meta
4682

47-
# test a single image
48-
result = inference_mono_3d_detector(model, args.img, args.ann,
49-
args.cam_type)
83+
def main():
84+
# TODO: Support inference of point cloud numpy file.
85+
init_args, call_args = parse_args()
5086

51-
img = mmcv.imread(args.img)
52-
img = mmcv.imconvert(img, 'bgr', 'rgb')
87+
inferencer = MonoDet3DInferencer(**init_args)
88+
inferencer(**call_args)
5389

54-
data_input = dict(img=img)
55-
# show the results
56-
visualizer.add_datasample(
57-
'result',
58-
data_input,
59-
data_sample=result,
60-
draw_gt=False,
61-
show=args.show,
62-
wait_time=-1,
63-
out_file=args.out_dir,
64-
pred_score_thr=args.score_thr,
65-
vis_task='mono_det')
90+
if call_args['out_dir'] != '' and not (call_args['no_save_vis']
91+
and call_args['no_save_pred']):
92+
print_log(
93+
f'results have been saved at {call_args["out_dir"]}',
94+
logger='current')
6695

6796

6897
if __name__ == '__main__':
69-
args = parse_args()
70-
main(args)
98+
main()

0 commit comments

Comments
 (0)