Skip to content

Commit 733041f

Browse files
committed
add GPU number and lazy load img to GPU
1 parent 2eee0e2 commit 733041f

File tree

5 files changed

+8
-7
lines changed

5 files changed

+8
-7
lines changed

arguments/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self, parser, sentinel=False):
5252
self._images = "images"
5353
self._resolution = -1
5454
self._white_background = False
55+
self._device_number = 0
5556
self.data_device = "cuda"
5657
self.eval = False
5758
super().__init__(parser, "Loading Parameters", sentinel)

render.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,6 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam
6161
print("Rendering " + args.model_path)
6262

6363
# Initialize system state (RNG)
64-
safe_state(args.quiet)
64+
safe_state(args.quiet, args.device_number)
6565

6666
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)

scene/cameras.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
3636
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
3737
self.data_device = torch.device("cuda")
3838

39-
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
39+
self.original_image = image.clamp(0.0, 1.0)
4040
self.image_width = self.original_image.shape[2]
4141
self.image_height = self.original_image.shape[1]
4242

4343
if gt_alpha_mask is not None:
44-
self.original_image *= gt_alpha_mask.to(self.data_device)
44+
self.original_image *= gt_alpha_mask
4545
else:
46-
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
46+
self.original_image *= torch.ones((1, self.image_height, self.image_width))
4747

4848
self.zfar = 100.0
4949
self.znear = 0.01

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
211211
print("Optimizing " + args.model_path)
212212

213213
# Initialize system state (RNG)
214-
safe_state(args.quiet)
214+
safe_state(args.quiet, args.device_number)
215215

216216
# Start GUI server, configure and run training
217217
network_gui.init(args.ip, args.port)

utils/general_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def build_scaling_rotation(s, r):
109109
L = R @ L
110110
return L
111111

112-
def safe_state(silent):
112+
def safe_state(silent, device_number):
113113
old_f = sys.stdout
114114
class F:
115115
def __init__(self, silent):
@@ -130,4 +130,4 @@ def flush(self):
130130
random.seed(0)
131131
np.random.seed(0)
132132
torch.manual_seed(0)
133-
torch.cuda.set_device(torch.device("cuda:0"))
133+
torch.cuda.set_device(torch.device(f"cuda:{device_number}"))

0 commit comments

Comments
 (0)