Skip to content

Commit 52fb40a

Browse files
committed
feat: support INVR dataset
1 parent bb85fdb commit 52fb40a

File tree

3 files changed

+185
-74
lines changed

3 files changed

+185
-74
lines changed

examples/datasets/INVR_N3D.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import imageio.v2 as imageio
1010
from PIL import Image
1111
import numpy as np
12+
1213
import torch
1314
from pycolmap import SceneManager
1415
try:
@@ -38,23 +39,27 @@ def __init__(
3839
multiview: bool = False,
3940
duration: int = 5, # only for testing
4041
resolution_scales: list = [1.0],
41-
resolution: int = 2,
42+
downscale_factor: int = 2,
4243
data_device: str = "cpu",
44+
test_view_id: List[int] = [0]
4345
):
4446
self.model_path = model_path
4547
self.source_path = source_path
4648
self.images_phrase = images_phrase
4749
self.eval = eval
4850
self.duration = duration
4951
self.resolution_scales = resolution_scales
52+
self.test_view_id = test_view_id
5053

5154
self.train_cameras = {}
5255
self.test_cameras = {}
5356
raydict = {}
5457

5558
# Get scene info
5659
if loader == "colmap": # colmapvalid only for testing
57-
scene_info = sceneLoadTypeCallbacks["Colmap"](self.source_path, self.images_phrase, self.eval, multiview, duration=self.duration) # SceneInfo() - NamedTuple
60+
scene_info = sceneLoadTypeCallbacks["Colmap"](self.source_path, self.images_phrase, self.eval, multiview, duration=self.duration, test_view_id=self.test_view_id, downscale_factor=downscale_factor) # SceneInfo() - NamedTuple
61+
elif loader == "invr":
62+
scene_info = sceneLoadTypeCallbacks["INVR"](self.source_path, self.images_phrase, self.eval, multiview, duration=self.duration) # SceneInfo() - NamedTuple
5863
else:
5964
assert False, "Could not recognize scene type!"
6065

@@ -65,7 +70,7 @@ def __init__(
6570
# need modification
6671
class ModelParams():
6772
def __init__(self):
68-
self.resolution = resolution
73+
self.downscale_factor = downscale_factor
6974
self.data_device = data_device
7075
args = ModelParams()
7176
self.args = args
@@ -121,6 +126,7 @@ def __init__(
121126
self.num_views = num_views
122127
self.parser = parser
123128
self.resolution_scale = self.parser.resolution_scales[0]
129+
self.split = split
124130

125131
if split == "train":
126132
self.scene_info = self.parser.scene_info[1]
@@ -153,7 +159,7 @@ def __init__(
153159

154160
self.start_frame = min(scene_by_t.keys())
155161

156-
def __len__(self):
162+
def __len__(self): # num of timestamp
157163
return self.fake_length if self.use_fake_length else len(self.scene_by_t)
158164
# return len(self.scene_info)
159165

@@ -164,16 +170,19 @@ def fetch_image(self, path):
164170
def __getitem__(self, index: int) -> Dict[str, Any]:
165171
tid = index % len(self.scene_by_t)
166172
t_infos = self.scene_by_t[tid + self.start_frame]
167-
try:
168-
frame_infos = random.sample(t_infos, k=self.num_views)
169-
except: #replace
170-
frame_infos = random.choices(t_infos, k=self.num_views)
171-
# frame_infos = np.random.choice(t_infos, self.num_views, replace=False)
173+
if self.split == "train":
174+
try:
175+
frame_infos = random.sample(t_infos, k=self.num_views)
176+
except: #replace
177+
frame_infos = random.choices(t_infos, k=self.num_views)
178+
else:
179+
frame_infos = t_infos[:self.num_views] # take out frames in single imgstamp by default order
180+
172181
K = self.parser.K
173-
scale = self.parser.args.resolution
182+
downscale_factor = self.parser.args.downscale_factor
174183
Ks, images, image_paths, rays, timesteps, camtoworlds = [], [], [], [], [], []
175184
for globalid, cami, finfo in frame_infos:
176-
resolution = (int(finfo.width / scale), int(finfo.height / scale))
185+
resolution = (int(finfo.width / downscale_factor), int(finfo.height / downscale_factor))
177186

178187
images.append(PILtoTorch_new(self.fetch_image(finfo.image_path), resolution).permute(1,2,0))
179188
image_paths.append(finfo.image_path)

examples/helper/STG/camera_utils.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
def loadCam(args, id, cam_info, resolution_scale):
2626
orig_w, orig_h = cam_info.image.size
2727

28-
if args.resolution in [1, 2, 4, 8]:
29-
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
28+
if args.downscale_factor in [1, 2, 4, 8]:
29+
resolution = round(orig_w/(resolution_scale * args.downscale_factor)), round(orig_h/(resolution_scale * args.downscale_factor))
3030
else: # should be a type that converts to float
31-
if args.resolution == -1:
31+
if args.downscale_factor == -1:
3232
if orig_w > 1600:
3333
global WARNED
3434
if not WARNED:
@@ -39,7 +39,7 @@ def loadCam(args, id, cam_info, resolution_scale):
3939
else:
4040
global_down = 1
4141
else:
42-
global_down = orig_w / args.resolution
42+
global_down = orig_w / args.downscale_factor
4343

4444
scale = float(global_down) * float(resolution_scale)
4545
resolution = (int(orig_w / scale), int(orig_h / scale))
@@ -70,10 +70,10 @@ def loadCam(args, id, cam_info, resolution_scale):
7070
# @timer
7171
def loadCamv2(args, id, cam_info, resolution_scale):
7272
orig_w, orig_h = cam_info.width, cam_info.height
73-
if args.resolution in [1, 2, 4, 8]:
74-
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
73+
if args.downscale_factor in [1, 2, 4, 8]:
74+
resolution = round(orig_w/(resolution_scale * args.downscale_factor)), round(orig_h/(resolution_scale * args.downscale_factor))
7575
else: # should be a type that converts to float
76-
if args.resolution == -1:
76+
if args.downscale_factor == -1:
7777
if orig_w > 1600:
7878
global WARNED
7979
if not WARNED:
@@ -84,7 +84,7 @@ def loadCamv2(args, id, cam_info, resolution_scale):
8484
else:
8585
global_down = 1
8686
else:
87-
global_down = orig_w / args.resolution
87+
global_down = orig_w / args.downscale_factor
8888

8989
scale = float(global_down) * float(resolution_scale)
9090
resolution = (int(orig_w / scale), int(orig_h / scale))
@@ -107,6 +107,10 @@ def loadCamv2(args, id, cam_info, resolution_scale):
107107
else :
108108
rays_o = None
109109
rays_d = None
110+
111+
if gt_image is None:
112+
gt_image = (resolution[0], resolution[1])
113+
110114
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
111115
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
112116
image=gt_image, gt_alpha_mask=loaded_mask,
@@ -122,10 +126,10 @@ def loadCamv2(args, id, cam_info, resolution_scale):
122126
def loadCamv2timing(args, id, cam_info, resolution_scale):
123127
orig_w, orig_h = cam_info.image.size
124128

125-
if args.resolution in [1, 2, 4, 8]:
126-
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
129+
if args.downscale_factor in [1, 2, 4, 8]:
130+
resolution = round(orig_w/(resolution_scale * args.downscale_factor)), round(orig_h/(resolution_scale * args.downscale_factor))
127131
else: # should be a type that converts to float
128-
if args.resolution == -1:
132+
if args.downscale_factor == -1:
129133
if orig_w > 1600:
130134
global WARNED
131135
if not WARNED:
@@ -136,7 +140,7 @@ def loadCamv2timing(args, id, cam_info, resolution_scale):
136140
else:
137141
global_down = 1
138142
else:
139-
global_down = orig_w / args.resolution
143+
global_down = orig_w / args.downscale_factor
140144

141145
scale = float(global_down) * float(resolution_scale)
142146
resolution = (int(orig_w / scale), int(orig_h / scale))
@@ -166,11 +170,11 @@ def loadCamv2timing(args, id, cam_info, resolution_scale):
166170

167171
def loadCamv2ss(args, id, cam_info, resolution_scale):
168172
orig_w, orig_h = cam_info.image.size
169-
assert args.resolution == 1
170-
if args.resolution in [1, 2, 4, 8]:
171-
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
173+
assert args.downscale_factor == 1
174+
if args.downscale_factor in [1, 2, 4, 8]:
175+
resolution = round(orig_w/(resolution_scale * args.downscale_factor)), round(orig_h/(resolution_scale * args.downscale_factor))
172176
else: # should be a type that converts to float
173-
if args.resolution == -1:
177+
if args.downscale_factor == -1:
174178
if orig_w > 1600:
175179
global WARNED
176180
if not WARNED:
@@ -181,7 +185,7 @@ def loadCamv2ss(args, id, cam_info, resolution_scale):
181185
else:
182186
global_down = 1
183187
else:
184-
global_down = orig_w / args.resolution
188+
global_down = orig_w / args.downscale_factor
185189

186190
scale = float(global_down) * float(resolution_scale)
187191
resolution = (int(orig_w / scale), int(orig_h / scale))
@@ -214,10 +218,10 @@ def loadCamv2ss(args, id, cam_info, resolution_scale):
214218
def loadCamnogt(args, id, cam_info, resolution_scale):
215219
orig_w, orig_h = cam_info.width, cam_info.height
216220

217-
if args.resolution in [1, 2, 4, 8]:
218-
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
221+
if args.downscale_factor in [1, 2, 4, 8]:
222+
resolution = round(orig_w/(resolution_scale * args.downscale_factor)), round(orig_h/(resolution_scale * args.downscale_factor))
219223
else: # should be a type that converts to float
220-
if args.resolution == -1:
224+
if args.downscale_factor == -1:
221225
if orig_w > 1600:
222226
global WARNED
223227
if not WARNED:
@@ -228,7 +232,7 @@ def loadCamnogt(args, id, cam_info, resolution_scale):
228232
else:
229233
global_down = 1
230234
else:
231-
global_down = orig_w / args.resolution
235+
global_down = orig_w / args.downscale_factor
232236

233237
scale = float(global_down) * float(resolution_scale)
234238
resolution = (int(orig_w / scale), int(orig_h / scale))

0 commit comments

Comments
 (0)