17
17
class Camera (nn .Module ):
18
18
def __init__ (self , colmap_id , R , T , FoVx , FoVy , image , gt_alpha_mask ,
19
19
image_name , uid ,
20
- trans = np .array ([0.0 , 0.0 , 0.0 ]), scale = 1.0 , data_device = "cuda"
20
+ trans = np .array ([0.0 , 0.0 , 0.0 ]), scale = 1.0 , data_device = "cuda" , lazy_load = False
21
21
):
22
22
super (Camera , self ).__init__ ()
23
23
@@ -36,14 +36,17 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
36
36
print (f"[Warning] Custom device { data_device } failed, fallback to default cuda device" )
37
37
self .data_device = torch .device ("cuda" )
38
38
39
- self .original_image = image .clamp (0.0 , 1.0 )
39
+ if lazy_load :
40
+ self .data_device = torch .device ("cpu" )
41
+
42
+ self .original_image = image .clamp (0.0 , 1.0 ).to (self .data_device )
40
43
self .image_width = self .original_image .shape [2 ]
41
44
self .image_height = self .original_image .shape [1 ]
42
45
43
46
if gt_alpha_mask is not None :
44
- self .original_image *= gt_alpha_mask
47
+ self .original_image *= gt_alpha_mask . to ( self . data_device )
45
48
else :
46
- self .original_image *= torch .ones ((1 , self .image_height , self .image_width ))
49
+ self .original_image *= torch .ones ((1 , self .image_height , self .image_width ), device = self . data_device )
47
50
48
51
self .zfar = 100.0
49
52
self .znear = 0.01
0 commit comments