22from PIL import Image
33from torch .utils .data import ConcatDataset
44from romatch .utils import get_tuple_transform_ops , get_depth_tuple_transform_ops
5+ from romatch .utils .transforms import RandomColorAug
6+ import torchvision .transforms .functional as tvf
57import numpy as np
68import torch
79
@@ -21,7 +23,12 @@ def __init__(
2123 wt = 560 ,
2224 min_overlap = 0.0 ,
2325 max_overlap = 1.0 ,
26+ shake_t = 0 ,
2427 normalize = True ,
28+ use_horizontal_flip_aug = False ,
29+ use_single_horizontal_flip_aug = False ,
30+ random_eraser = None ,
31+ use_randaug = False ,
2532 max_num_pairs = 20000 , # * total 2499030
2633 ) -> None :
2734 self .data_root = data_root
@@ -49,18 +56,38 @@ def __init__(
4956 self .pairs = self .pairs [pairinds ]
5057 self .overlaps = self .overlaps [pairinds ]
5158
52- self .wt , self .ht = wt , ht
5359 self .im_transform_ops = get_tuple_transform_ops (
5460 resize = (ht , wt ),
5561 normalize = normalize ,
5662 )
5763 self .depth_transform_ops = get_depth_tuple_transform_ops (
5864 resize = (ht , wt )
5965 )
66+ self .wt , self .ht = wt , ht
67+ self .shake_t = shake_t
68+
69+ if use_horizontal_flip_aug and use_single_horizontal_flip_aug :
70+ raise ValueError ("Can't both flip both images and only flip one" )
71+ self .use_horizontal_flip_aug = use_horizontal_flip_aug
72+ self .use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
73+
74+ self .use_randaug = use_randaug
75+ self .random_eraser = random_eraser
6076
def load_im(self, path):
    """Load the file at *path* as a PIL Image."""
    return Image.open(path)
6379
def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
    """Mirror both images and depth maps left-right and update intrinsics.

    Returns the flipped (im_A, im_B, depth_A, depth_B, K_A, K_B) tuple.
    """
    # Reflect every spatial tensor along its last (width) axis so the
    # pair stays geometrically consistent.
    flipped = [t.flip(-1) for t in (im_A, im_B, depth_A, depth_B)]
    # Homography for the mirror x' = wt - x applied to the intrinsics.
    # NOTE(review): uses self.wt rather than self.wt - 1; this matches the
    # original convention — confirm against the pixel coordinate system.
    mirror = torch.tensor(
        [[-1.0, 0.0, self.wt],
         [0.0, 1.0, 0.0],
         [0.0, 0.0, 1.0]],
    ).to(K_A.device)
    return (*flipped, mirror @ K_A, mirror @ K_B)
90+
def load_depth(self, depth_ref):
    """Read a depth map from disk and return channel 0 as a torch tensor."""
    raw = cv2.imread(depth_ref, cv2.IMREAD_UNCHANGED)
    # Only the first channel is kept.
    # NOTE(review): assumes the file decodes to a multi-channel array —
    # confirm the on-disk depth format.
    return torch.tensor(raw[:, :, 0])
@@ -73,6 +100,24 @@ def scale_intrinsic(self, K, wi, hi):
73100 sK = torch .tensor ([[sx , 0 , 0 ], [0 , sy , 0 ], [0 , 0 , 1 ]])
74101 return sK @ K
75102
def rand_shake(self, *things):
    """Apply one shared random integer translation to every input tensor.

    Returns ([shifted tensors], offsets) where offsets is the (dx, dy)
    drawn once, so all inputs stay spatially aligned with each other.
    """
    offsets = np.random.choice(
        range(-self.shake_t, self.shake_t + 1), size=2)
    shifted = []
    for thing in things:
        shifted.append(
            tvf.affine(
                thing,
                angle=0.0,
                translate=list(offsets),
                scale=1.0,
                shear=[0.0, 0.0],
            )
        )
    return shifted, offsets
110+
def rand_augment(self, im_A, im_B):
    """Run the same color augmentation policy over both PIL images."""
    # One augmenter instance is shared for the pair; A is augmented
    # before B, matching the original call order.
    augment = RandomColorAug()
    out_A = Image.fromarray(augment(np.array(im_A)))
    out_B = Image.fromarray(augment(np.array(im_B)))
    return out_A, out_B
120+
76121 def __getitem__ (self , pair_idx ):
77122 # read intrinsics of original size
78123 idx1 , idx2 = self .pairs [pair_idx ]
@@ -104,40 +149,46 @@ def __getitem__(self, pair_idx):
104149 K1 = self .scale_intrinsic (K1 , im_A .width , im_A .height )
105150 K2 = self .scale_intrinsic (K2 , im_B .width , im_B .height )
106151
107- # * im_A: (640, 512) ImageFile
108- # * depth_A: [512, 640]
109- # plt.figure()
110- # plt.subplot(2, 2, 1)
111- # plt.imshow(im_A)
112- # plt.subplot(2, 2, 2)
113- # plt.imshow(depth_A)
114-
115152 # Process images
116- im_A , im_B = self .im_transform_ops ((im_A , im_B ))
117- depth_A , depth_B = self .depth_transform_ops (
118- (depth_A [None , None ], depth_B [None , None ])
119- )
120-
121- # * im_A: [3, 560, 560]
122- # * depth_A: [1, 1, 560, 560]
123- # plt.subplot(2, 2, 3)
124- # plt.imshow(im_A.permute(1, 2, 0) * 0.5 + 0.5)
125- # plt.subplot(2, 2, 4)
126- # plt.imshow(depth_A[0, 0])
127- # plt.tight_layout()
128- # plt.show()
153+ try :
154+ if self .use_randaug :
155+ im_A , im_B = self .rand_augment (im_A , im_B )
129156
130- im_A , im_B = im_A [None ], im_B [None ]
157+ im_A , im_B = self .im_transform_ops ((im_A , im_B ))
158+ depth_A , depth_B = self .depth_transform_ops (
159+ (depth_A [None , None ], depth_B [None , None ])
160+ )
131161
132- # * im_A: [1, 3, 560, 560]
133- # * depth_A: [1, 1, 560, 560]
162+ [im_A , im_B , depth_A , depth_B ], t = self .rand_shake (
163+ im_A , im_B , depth_A , depth_B )
164+ K1 [:2 , 2 ] += t
165+ K2 [:2 , 2 ] += t
166+
167+ im_A , im_B = im_A [None ], im_B [None ]
168+ if self .random_eraser is not None :
169+ im_A , depth_A = self .random_eraser (im_A , depth_A )
170+ im_B , depth_B = self .random_eraser (im_B , depth_B )
171+
172+ if self .use_horizontal_flip_aug :
173+ if np .random .rand () > 0.5 :
174+ im_A , im_B , depth_A , depth_B , K1 , K2 = self .horizontal_flip (
175+ im_A , im_B , depth_A , depth_B , K1 , K2 )
176+
177+ if self .use_single_horizontal_flip_aug :
178+ if np .random .rand () > 0.5 :
179+ im_B , depth_B , K2 = self .single_horizontal_flip (
180+ im_B , depth_B , K2 )
181+ except Exception as e :
182+ print (
183+ f"Error in transform ({ self .image_paths [idx1 ]} , { self .image_paths [idx1 ]} ):" , e )
184+ return None
134185
135186 data_dict = {
136- "im_A" : im_A [0 ], # * [3, 560, 560]
187+ "im_A" : im_A [0 ],
137188 "im_A_identifier" : self .image_paths [idx1 ].split ("/" )[- 1 ].split (".jpg" )[0 ],
138189 "im_B" : im_B [0 ],
139190 "im_B_identifier" : self .image_paths [idx2 ].split ("/" )[- 1 ].split (".jpg" )[0 ],
140- "im_A_depth" : depth_A [0 , 0 ], # * [560, 560]
191+ "im_A_depth" : depth_A [0 , 0 ],
141192 "im_B_depth" : depth_B [0 , 0 ],
142193 "K1" : K1 ,
143194 "K2" : K2 ,
@@ -154,19 +205,19 @@ def __init__(self, data_root: str = "./data/extredata") -> None:
154205 self .data_root = data_root
155206 self .scene_info_root = os .path .join (data_root , "scene_info" )
156207 self .all_scenes = set (os .listdir (self .scene_info_root ))
157- self .test_scenes = {"Madrid4_117@-83@276@68@0@90.npy " ,
158- "Madrid4_90@-33@76@58@0@90.npy " ,
159- "Madrid1_93@467@65@51@0@90.npy " ,
160- "Berlin6_141@17@21@70@0@90.npy " ,
161- "Tokyo5_92@167@326@64@0@90.npy " ,
162- "Madrid1_93@-233@-385@59@0@90.npy " ,
163- "German5_61@-263@139@63@0@90.npy " ,
164- "Milano3_123@-434@318@58@0@90.npy " ,
165- "NewYork4_138@-133@169@66@0@90.npy " ,
166- "Bern0_143@216@-387@51@0@90.npy " ,
167- "Berlin0_111@-133@-280@52@0@90.npy " ,
168- "Madrid0_122@167@215@51@0@90.npy " ,
169- "Milano2_134@116@218@51@0@90.npy " }
208+ self .test_scenes = {"Madrid4_117@-83@276@68@0@90.npz " ,
209+ "Madrid4_90@-33@76@58@0@90.npz " ,
210+ "Madrid1_93@467@65@51@0@90.npz " ,
211+ "Berlin6_141@17@21@70@0@90.npz " ,
212+ "Tokyo5_92@167@326@64@0@90.npz " ,
213+ "Madrid1_93@-233@-385@59@0@90.npz " ,
214+ "German5_61@-263@139@63@0@90.npz " ,
215+ "Milano3_123@-434@318@58@0@90.npz " ,
216+ "NewYork4_138@-133@169@66@0@90.npz " ,
217+ "Bern0_143@216@-387@51@0@90.npz " ,
218+ "Berlin0_111@-133@-280@52@0@90.npz " ,
219+ "Madrid0_122@167@215@51@0@90.npz " ,
220+ "Milano2_134@116@218@51@0@90.npz " }
170221 self .ignore_scenes = set ()
171222
172223 def build_scenes (self , split : str = "train" , ** kwargs ):
@@ -179,10 +230,10 @@ def build_scenes(self, split: str = "train", **kwargs):
179230
180231 scenes = []
181232 for scene_name in scene_names :
182- if ".npy " not in scene_name :
233+ if ".npz " not in scene_name :
183234 continue
184235 scene_info_path = os .path .join (self .scene_info_root , scene_name )
185- scene_info = np .load (scene_info_path , allow_pickle = True ). item ()
236+ scene_info = np .load (scene_info_path , allow_pickle = True )
186237 scene = ExtredataScene (
187238 data_root = self .data_root ,
188239 scene_info = scene_info ,
@@ -199,10 +250,3 @@ def weight_scenes(self, concat_dataset, alpha: float = 0.5) -> torch.Tensor:
199250 ns .append (len (d ))
200251 ws = torch .cat ([torch .ones (n ) / n ** alpha for n in ns ])
201252 return ws
202-
203-
204- if __name__ == "__main__" :
205- dataset = ExtredataBuilder ()
206- train1 = dataset .build_scenes ()
207- train = ConcatDataset (train1 )
208- print (len (train )) # * 2499030
0 commit comments