-
Notifications
You must be signed in to change notification settings - Fork 69
Expand file tree
/
Copy pathvismatch_match.py
More file actions
123 lines (97 loc) · 4.29 KB
/
vismatch_match.py
File metadata and controls
123 lines (97 loc) · 4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
This script performs image matching using a specified matcher model. It processes pairs of input images,
detects keypoints, matches them, and performs RANSAC to find inliers. The results, including visualizations
and metadata, are saved to the specified output directory.
"""
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
import torch
import argparse
import time
from pathlib import Path
from vismatch.utils import get_image_pairs_paths, get_default_device
from vismatch import get_matcher, available_models
from vismatch.viz import plot_matches
COL_WIDTH = 22
def parse_args():
    """Build and parse the CLI arguments for the matching script.

    Returns the parsed ``argparse.Namespace``; ``out_dir`` is filled in with
    ``outputs/{matcher}`` when the user did not supply one.
    """
    # Lay the available matcher names out in columns for the help epilog
    # (visible via: python vismatch_match.py -h)
    model_names = sorted(available_models)
    per_row, col_w = 4, 35
    name_rows = []
    for start in range(0, len(model_names), per_row):
        row = model_names[start : start + per_row]
        name_rows.append(" " + "".join(name.ljust(col_w) for name in row))

    parser = argparse.ArgumentParser(
        prog="vismatch-match",
        description="Match keypoints between image pairs. Outputs match visualizations and result dicts.",
        epilog=f"Available matchers ({len(model_names)}):\n" + "\n".join(name_rows),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--matcher",
        type=str,
        default="superpoint-lightglue",
        choices=available_models,
        metavar="MODEL",  # hide the (long) choices list from the usage line
        help="matcher to use (default: %(default)s). See list below",
    )
    # Hyperparameters common to every matcher backend
    parser.add_argument(
        "--img-size", type=int, default=512, help="resize img to img-size x img-size (default: %(default)s)"
    )
    parser.add_argument("--n-kpts", type=int, default=2048, help="max num keypoints (default: %(default)s)")
    parser.add_argument(
        "--device",
        type=str,
        default=get_default_device(),
        choices=["cpu", "cuda"],
        help="device to run on (default: %(default)s)",
    )
    parser.add_argument("--no-viz", action="store_true", help="avoid saving visualizations")
    parser.add_argument(
        "--input",
        type=Path,
        nargs="+",  # one or more paths; interpretation is decided downstream
        required=True,
        help="path to either (1) two image paths or (2) dir with two images or (3) dir with dirs with image pairs or "
        "(4) txt file with two image paths per line",
    )
    parser.add_argument(
        "--out-dir", type=Path, default=None, help="path where outputs are saved (default: outputs/{matcher})"
    )

    args = parser.parse_args()
    # Default the output directory per-matcher so runs don't clobber each other
    if args.out_dir is None:
        args.out_dir = Path("outputs") / args.matcher
    return args
def main():
    """Match every input image pair, printing a report and saving results.

    For each pair: load and resize both images, run the matcher, optionally
    save a match visualization, and always save the result dict via
    ``torch.save`` into ``args.out_dir``.
    """
    args = parse_args()
    target_size = [args.img_size, args.img_size]
    args.out_dir.mkdir(exist_ok=True, parents=True)

    # Instantiate the requested matcher once; it is reused for every pair.
    matcher = get_matcher(args.matcher, device=args.device, max_num_keypoints=args.n_kpts)
    print(f"Using matcher: {args.matcher} on device: {args.device}")
    print("=" * 80)

    for idx, (path0, path1) in enumerate(get_image_pairs_paths(args.input)):
        t_start = time.time()

        img_a = matcher.load_image(path0, resize=target_size)
        img_b = matcher.load_image(path1, resize=target_size)
        result = matcher(img_a, img_b)

        report = f"{'Input':<{COL_WIDTH}}: {path0}, {path1}\n"
        report += f"{'Inliers (post-RANSAC)':<{COL_WIDTH}}: {result['num_inliers']}\n"

        if not args.no_viz:
            viz_path = args.out_dir / f"output_{idx}_matches.jpg"
            plot_matches(img_a, img_b, result, save_path=viz_path, show_all_kpts=True)
            report += f"{'Viz saved in':<{COL_WIDTH}}: {viz_path}\n"

        # Record run metadata alongside the match results before serializing.
        result["img0_path"] = path0
        result["img1_path"] = path1
        result["matcher"] = args.matcher
        result["n_kpts"] = args.n_kpts
        result["im_size"] = args.img_size

        dict_path = args.out_dir / f"output_{idx}_result.torch"
        torch.save(result, dict_path)
        report += f"{'Output saved in':<{COL_WIDTH}}: {dict_path}\n"
        report += f"{'Time taken (s)':<{COL_WIDTH}}: {time.time() - t_start:.3f}\n"
        print(report)
# Script entry point: run matching only when executed directly, not on import.
if __name__ == "__main__":
    main()