44import os
55import os .path as osp
66import time
7- from functools import partial
87
9- from tqdm import tqdm
10- import numpy as np
118import mmcv
129import mmengine
10+ import numpy as np
11+ from tqdm import tqdm
1312
1413from mmdeploy .apis import visualize_model
15- from mmdeploy .utils import (Backend , get_backend , get_root_logger ,
16- load_config )
14+ from mmdeploy .utils import Backend , get_backend , get_root_logger , load_config
1715
1816
1917def parse_args ():
20- parser = argparse .ArgumentParser (description = 'Model inference visualization.' )
18+ parser = argparse .ArgumentParser (
19+ description = 'Model inference visualization.' )
2120 parser .add_argument ('--deploy-cfg' , help = 'deploy config path' )
2221 parser .add_argument ('--model-cfg' , help = 'model config path' )
23- parser .add_argument ('--deploy-path' ,
24- type = str ,
25- nargs = '+' ,
26- help = 'deploy model path' )
2722 parser .add_argument (
28- '--checkpoint ' ,
29- default = None ,
30- help = 'model checkpoint path' )
23+ '--deploy-path ' , type = str , nargs = '+' , help = 'deploy model path' )
24+ parser . add_argument (
25+ '--checkpoint' , default = None , help = 'model checkpoint path' )
3126 parser .add_argument (
3227 '--test-img' ,
3328 default = None ,
3429 type = str ,
3530 nargs = '+' ,
3631 help = 'image used to test model' )
3732 parser .add_argument (
38- '--save-dir' ,
39- default = None ,
40- help = 'the dir to save inference results' )
41- parser .add_argument (
42- '--device' , help = 'device to run model' , default = 'cpu' )
33+ '--save-dir' , default = None , help = 'the dir to save inference results' )
34+ parser .add_argument ('--device' , help = 'device to run model' , default = 'cpu' )
4335 parser .add_argument (
4436 '--log-level' ,
4537 help = 'set log level' ,
@@ -65,61 +57,48 @@ def main():
6557 deploy_model_path = args .deploy_path
6658 if not isinstance (deploy_model_path , list ):
6759 deploy_model_path = [deploy_model_path ]
68-
60+
6961 # load deploy_cfg
7062 deploy_cfg = load_config (deploy_cfg_path )[0 ]
71-
63+
7264 # create save_dir or generate default save_dir
7365 save_dir = args .save_dir
7466 if save_dir :
7567 # generate default dir
7668 current_time = time .localtime ()
77- save_dir = osp .join (
78- os .getcwd (), time .strftime ("%Y_%m_%d_%H_%M_%S" , current_time )
79- )
69+ save_dir = osp .join (os .getcwd (),
70+ time .strftime ('%Y_%m_%d_%H_%M_%S' , current_time ))
8071 mmengine .mkdir_or_exist (save_dir )
8172
8273 # get backend info
8374 backend = get_backend (deploy_cfg )
8475 extra = dict ()
8576 if backend == Backend .SNPE :
8677 extra ['uri' ] = args .uri
87-
78+
8879 # iterate single_img
8980 for single_img in tqdm (args .test_img ):
9081 filename = osp .basename (single_img )
91- output_file = osp .join (save_dir , filename )
92- visualize_model (
93- model_cfg_path ,
94- deploy_cfg_path ,
95- deploy_model_path ,
96- single_img ,
97- args .device ,
98- backend ,
99- output_file ,
100- False ,
101- ** extra )
102-
82+ output_file = osp .join (save_dir , filename )
83+ visualize_model (model_cfg_path , deploy_cfg_path , deploy_model_path ,
84+ single_img , args .device , backend , output_file , False ,
85+ ** extra )
86+
10387 if checkpoint_path :
10488 pytorch_output_file = osp .join (save_dir , 'pytorch_out.jpg' )
105- visualize_model (
106- model_cfg_path ,
107- deploy_cfg_path ,
108- [checkpoint_path ],
109- single_img ,
110- args .device ,
111- Backend .PYTORCH ,
112- pytorch_output_file ,
113- False )
114-
89+ visualize_model (model_cfg_path , deploy_cfg_path , [checkpoint_path ],
90+ single_img , args .device , Backend .PYTORCH ,
91+ pytorch_output_file , False )
92+
11593 # concat pytorch result and backend result
11694 backend_result = mmcv .imread (output_file )
11795 pytorch_result = mmcv .imread (pytorch_output_file )
11896 result = np .concatenate ((backend_result , pytorch_result ), axis = 1 )
11997 mmcv .imwrite (result , output_file )
120-
98+
12199 # remove temp pytorch result
122100 os .remove (osp .join (save_dir , pytorch_output_file ))
123101
102+
124103if __name__ == '__main__' :
125- main ()
104+ main ()
0 commit comments