Skip to content

Commit 3148ab1

Browse files
committed
add lazy_load argument to control whether or not
2 parents 733041f + 996185e commit 3148ab1

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

arguments/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(self, parser, sentinel=False):
5353
self._resolution = -1
5454
self._white_background = False
5555
self._device_number = 0
56+
self._lazy_load = False
5657
self.data_device = "cuda"
5758
self.eval = False
5859
super().__init__(parser, "Loading Parameters", sentinel)

scene/cameras.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class Camera(nn.Module):
1818
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
1919
image_name, uid,
20-
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
20+
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", lazy_load=False
2121
):
2222
super(Camera, self).__init__()
2323

@@ -36,14 +36,17 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
3636
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
3737
self.data_device = torch.device("cuda")
3838

39-
self.original_image = image.clamp(0.0, 1.0)
39+
if lazy_load:
40+
self.data_device = torch.device("cpu")
41+
42+
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
4043
self.image_width = self.original_image.shape[2]
4144
self.image_height = self.original_image.shape[1]
4245

4346
if gt_alpha_mask is not None:
44-
self.original_image *= gt_alpha_mask
47+
self.original_image *= gt_alpha_mask.to(self.data_device)
4548
else:
46-
self.original_image *= torch.ones((1, self.image_height, self.image_width))
49+
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
4750

4851
self.zfar = 100.0
4952
self.znear = 0.01

utils/camera_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
4949
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
5050
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
5151
image=gt_image, gt_alpha_mask=loaded_mask,
52-
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
52+
image_name=cam_info.image_name, uid=id, data_device=args.data_device, lazy_load=args.lazy_load)
5353

5454
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
5555
camera_list = []

0 commit comments

Comments
 (0)