1
1
# Copyright (c) OpenMMLab. All rights reserved.
2
+ import logging
3
+ import os
2
4
from argparse import ArgumentParser
3
5
4
- import mmcv
6
+ from mmengine . logging import print_log
5
7
6
- from mmdet3d .apis import inference_mono_3d_detector , init_model
7
- from mmdet3d .registry import VISUALIZERS
8
+ from mmdet3d .apis import MonoDet3DInferencer
8
9
9
10
10
11
def parse_args ():
11
12
parser = ArgumentParser ()
12
- parser .add_argument ('img' , help = 'image file' )
13
- parser .add_argument ('ann ' , help = 'ann file' )
14
- parser .add_argument ('config ' , help = 'Config file' )
15
- parser .add_argument ('checkpoint ' , help = 'Checkpoint file' )
13
+ parser .add_argument ('img' , help = 'Image file' )
14
+ parser .add_argument ('infos ' , help = 'Infos file with annotations ' )
15
+ parser .add_argument ('model ' , help = 'Config file' )
16
+ parser .add_argument ('weights ' , help = 'Checkpoint file' )
16
17
parser .add_argument (
17
18
'--device' , default = 'cuda:0' , help = 'Device used for inference' )
18
19
parser .add_argument (
@@ -21,50 +22,77 @@ def parse_args():
21
22
default = 'CAM_BACK' ,
22
23
help = 'choose camera type to inference' )
23
24
parser .add_argument (
24
- '--score-thr' , type = float , default = 0.30 , help = 'bbox score threshold' )
25
+ '--pred-score-thr' ,
26
+ type = float ,
27
+ default = 0.3 ,
28
+ help = 'bbox score threshold' )
25
29
parser .add_argument (
26
- '--out-dir' , type = str , default = 'demo' , help = 'dir to save results' )
30
+ '--out-dir' ,
31
+ type = str ,
32
+ default = 'outputs' ,
33
+ help = 'Output directory of prediction and visualization results.' )
27
34
parser .add_argument (
28
35
'--show' ,
29
36
action = 'store_true' ,
30
- help = 'show online visualization results' )
37
+ help = 'Show online visualization results' )
38
+ parser .add_argument (
39
+ '--wait-time' ,
40
+ type = float ,
41
+ default = - 1 ,
42
+ help = 'The interval of show (s). Demo will be blocked in showing'
43
+ 'results, if wait_time is -1. Defaults to -1.' )
31
44
parser .add_argument (
32
- '--snapshot ' ,
45
+ '--no-save-vis ' ,
33
46
action = 'store_true' ,
34
- help = 'whether to save online visualization results' )
35
- args = parser .parse_args ()
36
- return args
47
+ help = 'Do not save detection visualization results' )
48
+ parser .add_argument (
49
+ '--no-save-pred' ,
50
+ action = 'store_true' ,
51
+ help = 'Do not save detection prediction results' )
52
+ parser .add_argument (
53
+ '--print-result' ,
54
+ action = 'store_true' ,
55
+ help = 'Whether to print the results.' )
56
+ call_args = vars (parser .parse_args ())
57
+
58
+ call_args ['inputs' ] = dict (
59
+ img = call_args .pop ('img' ), infos = call_args .pop ('infos' ))
60
+ call_args .pop ('cam_type' )
61
+
62
+ if call_args ['no_save_vis' ] and call_args ['no_save_pred' ]:
63
+ call_args ['out_dir' ] = ''
64
+
65
+ init_kws = ['model' , 'weights' , 'device' ]
66
+ init_args = {}
67
+ for init_kw in init_kws :
68
+ init_args [init_kw ] = call_args .pop (init_kw )
37
69
70
+ # NOTE: If your operating environment does not have a display device,
71
+ # (e.g. a remote server), you can save the predictions and visualize
72
+ # them in local devices.
73
+ if os .environ .get ('DISPLAY' ) is None and call_args ['show' ]:
74
+ print_log (
75
+ 'Display device not found. `--show` is forced to False' ,
76
+ logger = 'current' ,
77
+ level = logging .WARNING )
78
+ call_args ['show' ] = False
38
79
39
- def main (args ):
40
- # build the model from a config file and a checkpoint file
41
- model = init_model (args .config , args .checkpoint , device = args .device )
80
+ return init_args , call_args
42
81
43
- # init visualizer
44
- visualizer = VISUALIZERS .build (model .cfg .visualizer )
45
- visualizer .dataset_meta = model .dataset_meta
46
82
47
- # test a single image
48
- result = inference_mono_3d_detector ( model , args . img , args . ann ,
49
- args . cam_type )
83
+ def main ():
84
+ # TODO: Support inference of point cloud numpy file.
85
+ init_args , call_args = parse_args ( )
50
86
51
- img = mmcv . imread ( args . img )
52
- img = mmcv . imconvert ( img , 'bgr' , 'rgb' )
87
+ inferencer = MonoDet3DInferencer ( ** init_args )
88
+ inferencer ( ** call_args )
53
89
54
- data_input = dict (img = img )
55
- # show the results
56
- visualizer .add_datasample (
57
- 'result' ,
58
- data_input ,
59
- data_sample = result ,
60
- draw_gt = False ,
61
- show = args .show ,
62
- wait_time = - 1 ,
63
- out_file = args .out_dir ,
64
- pred_score_thr = args .score_thr ,
65
- vis_task = 'mono_det' )
90
+ if call_args ['out_dir' ] != '' and not (call_args ['no_save_vis' ]
91
+ and call_args ['no_save_pred' ]):
92
+ print_log (
93
+ f'results have been saved at { call_args ["out_dir" ]} ' ,
94
+ logger = 'current' )
66
95
67
96
68
97
if __name__ == '__main__' :
69
- args = parse_args ()
70
- main (args )
98
+ main ()
0 commit comments