Skip to content

Commit e1938e4

Browse files
simplying code and adding GPU serialization fix for both datamanagers
1 parent f87446f commit e1938e4

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

nerfstudio/data/datamanagers/full_images_datamanager.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from nerfstudio.data.datasets.base_dataset import InputDataset
4747
from nerfstudio.data.utils.data_utils import identity_collate
4848
from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image
49-
from nerfstudio.utils.misc import get_orig_class
49+
from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
5050
from nerfstudio.utils.rich_utils import CONSOLE
5151

5252

@@ -390,9 +390,7 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
390390
if self.config.cache_images == "disk":
391391
camera, data = next(self.iter_train_image_dataloader)[0]
392392
camera = camera.to(self.device)
393-
for k in data.keys():
394-
if isinstance(data[k], torch.Tensor):
395-
data[k] = data[k].to(self.device)
393+
data = get_dict_to_torch(data, self.device)
396394
return camera, data
397395

398396
image_idx = self.train_unseen_cameras.pop(0)
@@ -420,9 +418,7 @@ def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
420418
if self.config.cache_images == "disk":
421419
camera, data = next(self.iter_eval_image_dataloader)[0]
422420
camera = camera.to(self.device)
423-
for k in data.keys():
424-
if isinstance(data[k], torch.Tensor):
425-
data[k] = data[k].to(self.device)
421+
data = get_dict_to_torch(data, self.device)
426422
return camera, data
427423

428424
return self.next_eval_image(step=step)

nerfstudio/data/datamanagers/parallel_datamanager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
RayBatchStream,
4141
variable_res_collate,
4242
)
43-
from nerfstudio.utils.misc import get_orig_class
43+
from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
4444
from nerfstudio.utils.rich_utils import CONSOLE
4545

4646

@@ -241,12 +241,16 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
241241
"""Returns the next batch of data from the train dataloader."""
242242
self.train_count += 1
243243
ray_bundle, batch = next(self.iter_train_raybundles)[0]
244+
ray_bundle = ray_bundle.to(self.device)
245+
batch = get_dict_to_torch(batch, self.device)
244246
return ray_bundle, batch
245247

246248
def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
247249
"""Returns the next batch of data from the eval dataloader."""
248250
self.eval_count += 1
249251
ray_bundle, batch = next(self.iter_train_raybundles)[0]
252+
ray_bundle = ray_bundle.to(self.device)
253+
batch = get_dict_to_torch(batch, self.device)
250254
return ray_bundle, batch
251255

252256
def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:

nerfstudio/data/utils/dataloaders.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,18 +574,18 @@ def __iter__(self):
574574
"""
575575
Here, the variable 'batch' refers to the output of our pixel sampler.
576576
- batch is a dict_keys(['image', 'indices'])
577-
- batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch.
577+
- batch['image'] returns a `torch.Size([4096, 3])` tensor on CPU, where 4096 = num_rays_per_batch.
578578
- Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through.
579579
- The info of what specific image index that pixel belongs to is stored within batch[’indices’]
580-
- batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_idx, pixelRow, pixelCol)
580+
- batch['indices'] returns a `torch.Size([4096, 3])` tensor on CPU where each row represents (image_idx, pixelRow, pixelCol)
581581
pixel_sampler (for variable_res_collate) will loop though each image, samples pixel within the mask, and returns
582582
them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol)
583583
"""
584584
batch = worker_pixel_sampler.sample(collated_batch) # type: ignore
585585
# Note: collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device contains 'image'
586586
ray_indices = batch["indices"]
587-
# the ray_bundle is on the GPU; batch["image"] is on the CPU, here we move it to the GPU
588-
ray_bundle = self.ray_generator(ray_indices).to(self.device)
587+
# Both ray_bundle and batch["image"] are on the CPU and will be moved to the GPU in the main process (parallel_datamanager.py)
588+
ray_bundle = self.ray_generator(ray_indices)
589589
if self.custom_ray_processor:
590590
ray_bundle, batch = self.custom_ray_processor(ray_bundle, batch)
591591

0 commit comments

Comments
 (0)