Skip to content

Commit 0a8e585

Browse files
committed
Fix visualize script.
1 parent fbae550 commit 0a8e585

File tree

1 file changed

+27
-48
lines changed

1 file changed

+27
-48
lines changed

tools/visualize.py

Lines changed: 27 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,42 +4,34 @@
44
import os
55
import os.path as osp
66
import time
7-
from functools import partial
87

9-
from tqdm import tqdm
10-
import numpy as np
118
import mmcv
129
import mmengine
10+
import numpy as np
11+
from tqdm import tqdm
1312

1413
from mmdeploy.apis import visualize_model
15-
from mmdeploy.utils import (Backend, get_backend, get_root_logger,
16-
load_config)
14+
from mmdeploy.utils import Backend, get_backend, get_root_logger, load_config
1715

1816

1917
def parse_args():
20-
parser = argparse.ArgumentParser(description='Model inference visualization.')
18+
parser = argparse.ArgumentParser(
19+
description='Model inference visualization.')
2120
parser.add_argument('--deploy-cfg', help='deploy config path')
2221
parser.add_argument('--model-cfg', help='model config path')
23-
parser.add_argument('--deploy-path',
24-
type=str,
25-
nargs='+',
26-
help='deploy model path')
2722
parser.add_argument(
28-
'--checkpoint',
29-
default=None,
30-
help='model checkpoint path')
23+
'--deploy-path', type=str, nargs='+', help='deploy model path')
24+
parser.add_argument(
25+
'--checkpoint', default=None, help='model checkpoint path')
3126
parser.add_argument(
3227
'--test-img',
3328
default=None,
3429
type=str,
3530
nargs='+',
3631
help='image used to test model')
3732
parser.add_argument(
38-
'--save-dir',
39-
default=None,
40-
help='the dir to save inference results')
41-
parser.add_argument(
42-
'--device', help='device to run model', default='cpu')
33+
'--save-dir', default=None, help='the dir to save inference results')
34+
parser.add_argument('--device', help='device to run model', default='cpu')
4335
parser.add_argument(
4436
'--log-level',
4537
help='set log level',
@@ -65,61 +57,48 @@ def main():
6557
deploy_model_path = args.deploy_path
6658
if not isinstance(deploy_model_path, list):
6759
deploy_model_path = [deploy_model_path]
68-
60+
6961
# load deploy_cfg
7062
deploy_cfg = load_config(deploy_cfg_path)[0]
71-
63+
7264
# create save_dir or generate default save_dir
7365
save_dir = args.save_dir
7466
if save_dir:
7567
# generate default dir
7668
current_time = time.localtime()
77-
save_dir = osp.join(
78-
os.getcwd(), time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
79-
)
69+
save_dir = osp.join(os.getcwd(),
70+
time.strftime('%Y_%m_%d_%H_%M_%S', current_time))
8071
mmengine.mkdir_or_exist(save_dir)
8172

8273
# get backend info
8374
backend = get_backend(deploy_cfg)
8475
extra = dict()
8576
if backend == Backend.SNPE:
8677
extra['uri'] = args.uri
87-
78+
8879
# iterate single_img
8980
for single_img in tqdm(args.test_img):
9081
filename = osp.basename(single_img)
91-
output_file=osp.join(save_dir, filename)
92-
visualize_model(
93-
model_cfg_path,
94-
deploy_cfg_path,
95-
deploy_model_path,
96-
single_img,
97-
args.device,
98-
backend,
99-
output_file,
100-
False,
101-
**extra)
102-
82+
output_file = osp.join(save_dir, filename)
83+
visualize_model(model_cfg_path, deploy_cfg_path, deploy_model_path,
84+
single_img, args.device, backend, output_file, False,
85+
**extra)
86+
10387
if checkpoint_path:
10488
pytorch_output_file = osp.join(save_dir, 'pytorch_out.jpg')
105-
visualize_model(
106-
model_cfg_path,
107-
deploy_cfg_path,
108-
[checkpoint_path],
109-
single_img,
110-
args.device,
111-
Backend.PYTORCH,
112-
pytorch_output_file,
113-
False)
114-
89+
visualize_model(model_cfg_path, deploy_cfg_path, [checkpoint_path],
90+
single_img, args.device, Backend.PYTORCH,
91+
pytorch_output_file, False)
92+
11593
# concat pytorch result and backend result
11694
backend_result = mmcv.imread(output_file)
11795
pytorch_result = mmcv.imread(pytorch_output_file)
11896
result = np.concatenate((backend_result, pytorch_result), axis=1)
11997
mmcv.imwrite(result, output_file)
120-
98+
12199
# remove temp pytorch result
122100
os.remove(osp.join(save_dir, pytorch_output_file))
123101

102+
124103
if __name__ == '__main__':
125-
main()
104+
main()

0 commit comments

Comments
 (0)