Skip to content

Commit 33a7b19

Browse files
committed: "final code"
1 parent 7ab0610 commit 33a7b19

File tree

7 files changed

+171
-100
lines changed

7 files changed

+171
-100
lines changed

experiments/train_roma.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import romatch.utils.writer as writ
1111
from romatch.benchmarks import MixedDenseBenchmark, MixedVisualizeBenchmark
12-
from romatch.datasets import get_mixed_dataset, get_extredata_dataset
12+
from romatch.datasets import get_mixed_dataset, get_extredata_dataset, get_megadepth_dataset
1313
from romatch.losses.robust_loss import RobustLosses
1414
from romatch.utils.collate import collate_fn_with
1515

@@ -200,7 +200,7 @@ def train(args):
200200
experiment_name += "_pretrained_weights"
201201

202202
writ.init_writer(experiment_name, rank)
203-
pl.seed_everything(args.seed) # for reproducibility
203+
pl.seed_everything(args.seed)
204204

205205
checkpoint_dir = "workspace/checkpoints/"
206206
h, w = resolutions[resolution]
@@ -234,12 +234,11 @@ def train(args):
234234
if not romatch.TEST_MODE:
235235
# Data
236236
if args.use_pretained_roma:
237-
# When finetuning, use extredata dataset only
238-
dataset, dataset_ws = get_extredata_dataset(
239-
h, w, train=True)
237+
dataset, dataset_ws = get_mixed_dataset(
238+
h, w, train=True, mega_percent=0.8)
240239
else:
241240
dataset, dataset_ws = get_mixed_dataset(
242-
h, w, train=True, mega_percent=0.1)
241+
h, w, train=True, mega_percent=0.9)
243242

244243
# Loss and optimizer
245244
depth_loss = RobustLosses(
@@ -254,8 +253,8 @@ def train(args):
254253
if args.use_pretained_roma:
255254
# Use smaller learning rate for pretrained weights
256255
parameters = [
257-
{"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 80},
258-
{"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 80},
256+
{"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 800},
257+
{"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 800},
259258
]
260259
else:
261260
parameters = [
@@ -271,11 +270,8 @@ def train(args):
271270
h=h, w=w, num_samples=1000, dataset="extredata")
272271
megadepth_benchmark = MixedDenseBenchmark(
273272
h=h, w=w, num_samples=1000, dataset="megadepth")
274-
275-
# When finetuning, use extredata dataset only
276-
vis_dataset = "extredata" if args.use_pretained_roma else "mixed"
277273
mixed_visualize_benchmark = MixedVisualizeBenchmark(
278-
h=h, w=w, count=8, dataset=vis_dataset)
274+
h=h, w=w, count=8, dataset="mixed")
279275

280276
checkpointer = CheckPoint(checkpoint_dir, experiment_name)
281277
model, optimizer, lr_scheduler, global_step = checkpointer.load(
@@ -305,7 +301,7 @@ def train(args):
305301
dataset,
306302
batch_size=batch_size,
307303
sampler=sampler,
308-
num_workers=8,
304+
num_workers=32,
309305
collate_fn=collate_fn_with(dataset),
310306
)
311307
)

romatch/benchmarks/mixed_dense_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, h=384, w=512, num_samples=2000, dataset="megadepth") -> None:
1919
self.dataset, self.ws = get_megadepth_dataset(h, w, train=False)
2020
elif dataset == "mixed":
2121
self.dataset, self.ws = get_mixed_dataset(
22-
h, w, train=False, mega_percent=0.1)
22+
h, w, train=False, mega_percent=0.6)
2323

2424
def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
2525
b, h1, w1, d = dense_matches.shape

romatch/benchmarks/mixed_visualize_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, h=384, w=512, count=8, dataset="mixed") -> None:
1919
self.dataset, self.ws = get_megadepth_dataset(h, w, train=False)
2020
elif dataset == "mixed":
2121
self.dataset, self.ws = get_mixed_dataset(
22-
h, w, train=False, mega_percent=0.1)
22+
h, w, train=False, mega_percent=0.6)
2323

2424
self.sampler = torch.utils.data.WeightedRandomSampler(
2525
self.ws, replacement=False, num_samples=100

romatch/datasets/extredata.py

Lines changed: 93 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from PIL import Image
33
from torch.utils.data import ConcatDataset
44
from romatch.utils import get_tuple_transform_ops, get_depth_tuple_transform_ops
5+
from romatch.utils.transforms import RandomColorAug
6+
import torchvision.transforms.functional as tvf
57
import numpy as np
68
import torch
79

@@ -21,7 +23,12 @@ def __init__(
2123
wt=560,
2224
min_overlap=0.0,
2325
max_overlap=1.0,
26+
shake_t=0,
2427
normalize=True,
28+
use_horizontal_flip_aug=False,
29+
use_single_horizontal_flip_aug=False,
30+
random_eraser=None,
31+
use_randaug=False,
2532
max_num_pairs=20000, # * total 2499030
2633
) -> None:
2734
self.data_root = data_root
@@ -49,18 +56,38 @@ def __init__(
4956
self.pairs = self.pairs[pairinds]
5057
self.overlaps = self.overlaps[pairinds]
5158

52-
self.wt, self.ht = wt, ht
5359
self.im_transform_ops = get_tuple_transform_ops(
5460
resize=(ht, wt),
5561
normalize=normalize,
5662
)
5763
self.depth_transform_ops = get_depth_tuple_transform_ops(
5864
resize=(ht, wt)
5965
)
66+
self.wt, self.ht = wt, ht
67+
self.shake_t = shake_t
68+
69+
if use_horizontal_flip_aug and use_single_horizontal_flip_aug:
70+
raise ValueError("Can't both flip both images and only flip one")
71+
self.use_horizontal_flip_aug = use_horizontal_flip_aug
72+
self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
73+
74+
self.use_randaug = use_randaug
75+
self.random_eraser = random_eraser
6076

6177
def load_im(self, path):
6278
return Image.open(path)
6379

80+
def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
81+
im_A = im_A.flip(-1)
82+
im_B = im_B.flip(-1)
83+
depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
84+
flip_mat = torch.tensor(
85+
[[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.]]).to(K_A.device)
86+
K_A = flip_mat@K_A
87+
K_B = flip_mat@K_B
88+
89+
return im_A, im_B, depth_A, depth_B, K_A, K_B
90+
6491
def load_depth(self, depth_ref):
6592
depth = cv2.imread(depth_ref, cv2.IMREAD_UNCHANGED)
6693
return torch.tensor(depth[:, :, 0])
@@ -73,6 +100,24 @@ def scale_intrinsic(self, K, wi, hi):
73100
sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
74101
return sK @ K
75102

103+
def rand_shake(self, *things):
104+
t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
105+
return [
106+
tvf.affine(thing, angle=0.0, translate=list(
107+
t), scale=1.0, shear=[0.0, 0.0])
108+
for thing in things
109+
], t
110+
111+
def rand_augment(self, im_A, im_B):
112+
im_A = np.array(im_A)
113+
im_B = np.array(im_B)
114+
random_color_aug = RandomColorAug()
115+
im_A = random_color_aug(im_A)
116+
im_B = random_color_aug(im_B)
117+
im_A = Image.fromarray(im_A)
118+
im_B = Image.fromarray(im_B)
119+
return im_A, im_B
120+
76121
def __getitem__(self, pair_idx):
77122
# read intrinsics of original size
78123
idx1, idx2 = self.pairs[pair_idx]
@@ -104,40 +149,46 @@ def __getitem__(self, pair_idx):
104149
K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
105150
K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
106151

107-
# * im_A: (640, 512) ImageFile
108-
# * depth_A: [512, 640]
109-
# plt.figure()
110-
# plt.subplot(2, 2, 1)
111-
# plt.imshow(im_A)
112-
# plt.subplot(2, 2, 2)
113-
# plt.imshow(depth_A)
114-
115152
# Process images
116-
im_A, im_B = self.im_transform_ops((im_A, im_B))
117-
depth_A, depth_B = self.depth_transform_ops(
118-
(depth_A[None, None], depth_B[None, None])
119-
)
120-
121-
# * im_A: [3, 560, 560]
122-
# * depth_A: [1, 1, 560, 560]
123-
# plt.subplot(2, 2, 3)
124-
# plt.imshow(im_A.permute(1, 2, 0) * 0.5 + 0.5)
125-
# plt.subplot(2, 2, 4)
126-
# plt.imshow(depth_A[0, 0])
127-
# plt.tight_layout()
128-
# plt.show()
153+
try:
154+
if self.use_randaug:
155+
im_A, im_B = self.rand_augment(im_A, im_B)
129156

130-
im_A, im_B = im_A[None], im_B[None]
157+
im_A, im_B = self.im_transform_ops((im_A, im_B))
158+
depth_A, depth_B = self.depth_transform_ops(
159+
(depth_A[None, None], depth_B[None, None])
160+
)
131161

132-
# * im_A: [1, 3, 560, 560]
133-
# * depth_A: [1, 1, 560, 560]
162+
[im_A, im_B, depth_A, depth_B], t = self.rand_shake(
163+
im_A, im_B, depth_A, depth_B)
164+
K1[:2, 2] += t
165+
K2[:2, 2] += t
166+
167+
im_A, im_B = im_A[None], im_B[None]
168+
if self.random_eraser is not None:
169+
im_A, depth_A = self.random_eraser(im_A, depth_A)
170+
im_B, depth_B = self.random_eraser(im_B, depth_B)
171+
172+
if self.use_horizontal_flip_aug:
173+
if np.random.rand() > 0.5:
174+
im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
175+
im_A, im_B, depth_A, depth_B, K1, K2)
176+
177+
if self.use_single_horizontal_flip_aug:
178+
if np.random.rand() > 0.5:
179+
im_B, depth_B, K2 = self.single_horizontal_flip(
180+
im_B, depth_B, K2)
181+
except Exception as e:
182+
print(
183+
f"Error in transform ({self.image_paths[idx1]}, {self.image_paths[idx1]}):", e)
184+
return None
134185

135186
data_dict = {
136-
"im_A": im_A[0], # * [3, 560, 560]
187+
"im_A": im_A[0],
137188
"im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
138189
"im_B": im_B[0],
139190
"im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
140-
"im_A_depth": depth_A[0, 0], # * [560, 560]
191+
"im_A_depth": depth_A[0, 0],
141192
"im_B_depth": depth_B[0, 0],
142193
"K1": K1,
143194
"K2": K2,
@@ -154,19 +205,19 @@ def __init__(self, data_root: str = "./data/extredata") -> None:
154205
self.data_root = data_root
155206
self.scene_info_root = os.path.join(data_root, "scene_info")
156207
self.all_scenes = set(os.listdir(self.scene_info_root))
157-
self.test_scenes = {"Madrid4_117@-83@276@68@0@90.npy",
158-
"Madrid4_90@-33@76@58@0@90.npy",
159-
"Madrid1_93@467@65@51@0@90.npy",
160-
"Berlin6_141@17@21@70@0@90.npy",
161-
"Tokyo5_92@167@326@64@0@90.npy",
162-
"Madrid1_93@-233@-385@59@0@90.npy",
163-
"German5_61@-263@139@63@0@90.npy",
164-
"Milano3_123@-434@318@58@0@90.npy",
165-
"NewYork4_138@-133@169@66@0@90.npy",
166-
"Bern0_143@216@-387@51@0@90.npy",
167-
"Berlin0_111@-133@-280@52@0@90.npy",
168-
"Madrid0_122@167@215@51@0@90.npy",
169-
"Milano2_134@116@218@51@0@90.npy"}
208+
self.test_scenes = {"Madrid4_117@-83@276@68@0@90.npz",
209+
"Madrid4_90@-33@76@58@0@90.npz",
210+
"Madrid1_93@467@65@51@0@90.npz",
211+
"Berlin6_141@17@21@70@0@90.npz",
212+
"Tokyo5_92@167@326@64@0@90.npz",
213+
"Madrid1_93@-233@-385@59@0@90.npz",
214+
"German5_61@-263@139@63@0@90.npz",
215+
"Milano3_123@-434@318@58@0@90.npz",
216+
"NewYork4_138@-133@169@66@0@90.npz",
217+
"Bern0_143@216@-387@51@0@90.npz",
218+
"Berlin0_111@-133@-280@52@0@90.npz",
219+
"Madrid0_122@167@215@51@0@90.npz",
220+
"Milano2_134@116@218@51@0@90.npz"}
170221
self.ignore_scenes = set()
171222

172223
def build_scenes(self, split: str = "train", **kwargs):
@@ -179,10 +230,10 @@ def build_scenes(self, split: str = "train", **kwargs):
179230

180231
scenes = []
181232
for scene_name in scene_names:
182-
if ".npy" not in scene_name:
233+
if ".npz" not in scene_name:
183234
continue
184235
scene_info_path = os.path.join(self.scene_info_root, scene_name)
185-
scene_info = np.load(scene_info_path, allow_pickle=True).item()
236+
scene_info = np.load(scene_info_path, allow_pickle=True)
186237
scene = ExtredataScene(
187238
data_root=self.data_root,
188239
scene_info=scene_info,
@@ -199,10 +250,3 @@ def weight_scenes(self, concat_dataset, alpha: float = 0.5) -> torch.Tensor:
199250
ns.append(len(d))
200251
ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
201252
return ws
202-
203-
204-
if __name__ == "__main__":
205-
dataset = ExtredataBuilder()
206-
train1 = dataset.build_scenes()
207-
train = ConcatDataset(train1)
208-
print(len(train)) # * 2499030

romatch/datasets/megadepth.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import numpy as np
55
import torch
66
import torchvision.transforms.functional as tvf
7-
from romatch.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
8-
import romatch
97
from romatch.utils import *
8+
from romatch.utils.transforms import RandomColorAug
9+
import romatch
1010
import math
1111

1212

@@ -149,34 +149,38 @@ def __getitem__(self, pair_idx):
149149
K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
150150
K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
151151

152-
if self.use_randaug:
153-
im_A, im_B = self.rand_augment(im_A, im_B)
154-
155152
# Process images
156-
im_A, im_B = self.im_transform_ops((im_A, im_B))
157-
depth_A, depth_B = self.depth_transform_ops(
158-
(depth_A[None, None], depth_B[None, None])
159-
)
153+
try:
154+
if self.use_randaug:
155+
im_A, im_B = self.rand_augment(im_A, im_B)
160156

161-
[im_A, im_B, depth_A, depth_B], t = self.rand_shake(
162-
im_A, im_B, depth_A, depth_B)
163-
K1[:2, 2] += t
164-
K2[:2, 2] += t
165-
166-
im_A, im_B = im_A[None], im_B[None]
167-
if self.random_eraser is not None:
168-
im_A, depth_A = self.random_eraser(im_A, depth_A)
169-
im_B, depth_B = self.random_eraser(im_B, depth_B)
170-
171-
if self.use_horizontal_flip_aug:
172-
if np.random.rand() > 0.5:
173-
im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
174-
im_A, im_B, depth_A, depth_B, K1, K2)
175-
176-
if self.use_single_horizontal_flip_aug:
177-
if np.random.rand() > 0.5:
178-
im_B, depth_B, K2 = self.single_horizontal_flip(
179-
im_B, depth_B, K2)
157+
im_A, im_B = self.im_transform_ops((im_A, im_B))
158+
depth_A, depth_B = self.depth_transform_ops(
159+
(depth_A[None, None], depth_B[None, None])
160+
)
161+
162+
[im_A, im_B, depth_A, depth_B], t = self.rand_shake(
163+
im_A, im_B, depth_A, depth_B)
164+
K1[:2, 2] += t
165+
K2[:2, 2] += t
166+
167+
im_A, im_B = im_A[None], im_B[None]
168+
if self.random_eraser is not None:
169+
im_A, depth_A = self.random_eraser(im_A, depth_A)
170+
im_B, depth_B = self.random_eraser(im_B, depth_B)
171+
172+
if self.use_horizontal_flip_aug:
173+
if np.random.rand() > 0.5:
174+
im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
175+
im_A, im_B, depth_A, depth_B, K1, K2)
176+
177+
if self.use_single_horizontal_flip_aug:
178+
if np.random.rand() > 0.5:
179+
im_B, depth_B, K2 = self.single_horizontal_flip(
180+
im_B, depth_B, K2)
181+
except Exception as e:
182+
print(f"Error in transform ({self.image_paths[idx1]}, {self.image_paths[idx1]}):", e)
183+
return None
180184

181185
if romatch.DEBUG_MODE:
182186
tensor_to_pil(im_A[0], unnormalize=True)\

0 commit comments

Comments (0)