Skip to content

Commit dfd866c

Browse files
committed
add GPU number and lazy load img to GPU
1 parent 2eee0e2 commit dfd866c

File tree

6 files changed

+11
-6
lines changed

6 files changed

+11
-6
lines changed

arguments/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def __init__(self, parser, sentinel=False):
5252
self._images = "images"
5353
self._resolution = -1
5454
self._white_background = False
55+
self._device_number = 0
56+
self._lazy_load = False
5557
self.data_device = "cuda"
5658
self.eval = False
5759
super().__init__(parser, "Loading Parameters", sentinel)

render.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,6 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam
6161
print("Rendering " + args.model_path)
6262

6363
# Initialize system state (RNG)
64-
safe_state(args.quiet)
64+
safe_state(args.quiet, args.device_number)
6565

6666
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)

scene/cameras.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class Camera(nn.Module):
1818
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
1919
image_name, uid,
20-
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
20+
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", lazy_load=False
2121
):
2222
super(Camera, self).__init__()
2323

@@ -36,6 +36,9 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
3636
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
3737
self.data_device = torch.device("cuda")
3838

39+
if lazy_load:
40+
self.data_device = torch.device("cpu")
41+
3942
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
4043
self.image_width = self.original_image.shape[2]
4144
self.image_height = self.original_image.shape[1]

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
211211
print("Optimizing " + args.model_path)
212212

213213
# Initialize system state (RNG)
214-
safe_state(args.quiet)
214+
safe_state(args.quiet, args.device_number)
215215

216216
# Start GUI server, configure and run training
217217
network_gui.init(args.ip, args.port)

utils/camera_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
4949
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
5050
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
5151
image=gt_image, gt_alpha_mask=loaded_mask,
52-
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
52+
image_name=cam_info.image_name, uid=id, data_device=args.data_device, lazy_load=args.lazy_load)
5353

5454
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
5555
camera_list = []

utils/general_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def build_scaling_rotation(s, r):
109109
L = R @ L
110110
return L
111111

112-
def safe_state(silent):
112+
def safe_state(silent, device_number):
113113
old_f = sys.stdout
114114
class F:
115115
def __init__(self, silent):
@@ -130,4 +130,4 @@ def flush(self):
130130
random.seed(0)
131131
np.random.seed(0)
132132
torch.manual_seed(0)
133-
torch.cuda.set_device(torch.device("cuda:0"))
133+
torch.cuda.set_device(torch.device(f"cuda:{device_number}"))

0 commit comments

Comments
 (0)