Commit 94357f8

Dataloading Followup (#3604)

* fixing ns-eval when the datamanager is set to disk for splats
* fixing prefetch_factor type issues if the user decides to use 0 workers
* removing an accidental file
* fixing the depth-nerfacto config for issue 3592
* ruff linting the imports
* fixing the serialization error from workers creating GPU tensors
* finished next_eval
* adding flowcharts!
* fixing ruff
* added flowcharts into datamanagers.md
* simplifying code and adding the GPU serialization fix for both datamanagers
* fixing nits with consistent styling
* last optional nit
* fixing num_workers

1 parent 54b127f commit 94357f8

File tree (9 files changed: +47 −15 lines)

docs/developer_guides/pipelines/datamanagers.md

Lines changed: 26 additions & 0 deletions
````diff
@@ -115,6 +115,32 @@ To train splatfacto with a large dataset that's unable to fit in memory, please
 ns-train splatfacto --data {PROCESSED_DATA_DIR} --pipeline.datamanager.cache-images disk
 ```
 
+Check out these flowcharts for more customization on large datasets!
+
+```{image} imgs/DatamanagerGuide-LargeNeRF-light.png
+:align: center
+:class: only-light
+:width: 600
+```
+
+```{image} imgs/DatamanagerGuide-LargeNeRF-dark.png
+:align: center
+:class: only-dark
+:width: 600
+```
+
+```{image} imgs/DatamanagerGuide-Large3DGS-light.png
+:align: center
+:class: only-light
+:width: 600
+```
+
+```{image} imgs/DatamanagerGuide-Large3DGS-dark.png
+:align: center
+:class: only-dark
+:width: 600
+```
+
 ## Migrating Your DataManager to the new DataManager
 Many methods subclass a DataManager and add extra data to it. If you would like your custom datamanager to also support new parallel features, you can migrate any custom dataloading logic to the new `custom_ray_processor()` API. This function takes in a full training batch (either image or ray bundle) and allows the user to modify or add to it. Let's take a look at an example for the LERF method, which was built on Nerfstudio's VanillaDataManager. This API provides an interface to attach new information to the RayBundle (for ray-based methods), Cameras object (for splatting-based methods), or ground truth dictionary. It runs in a background process if disk caching is enabled, otherwise it runs in the main process.
````
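As a companion to the migration paragraph above, here is a minimal sketch of the `custom_ray_processor()` hook. Only the hook's name and its batch-in/batch-out contract come from the docs; the subclass and the feature-lookup attribute are hypothetical stand-ins for whatever custom data a method attaches.

```python
# Hypothetical sketch of migrating custom dataloading into custom_ray_processor().
# Assumes the hook receives the full ray batch and returns it, per the docs;
# my_feature_lookup is an invented placeholder for method-specific data.
from typing import Dict, Tuple

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager


class MyCustomDataManager(ParallelDataManager):
    def custom_ray_processor(self, ray_bundle: RayBundle, batch: Dict) -> Tuple[RayBundle, Dict]:
        # Attach extra ground truth keyed by the sampled pixel indices; this
        # runs in a background process when disk caching is enabled.
        batch["my_feature"] = self.my_feature_lookup(batch["indices"])
        return ray_bundle, batch
```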

docs/developer_guides/pipelines/imgs/DatamanagerGuide-*.png

4 binary flowchart images added (526 KB, 528 KB, 532 KB, 532 KB): the LargeNeRF and Large3DGS light/dark diagrams referenced in datamanagers.md above.

nerfstudio/configs/method_configs.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -219,7 +219,7 @@
     max_num_iterations=30000,
     mixed_precision=True,
     pipeline=VanillaPipelineConfig(
-        datamanager=VanillaDataManagerConfig(
+        datamanager=ParallelDataManagerConfig(
             _target=ParallelDataManager[DepthDataset],
             dataparser=NerfstudioDataParserConfig(),
             train_num_rays_per_batch=4096,
```
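For context on this one-line fix: nerfstudio configs instantiate their `_target` from the config itself, so pairing `_target=ParallelDataManager[DepthDataset]` with a `VanillaDataManagerConfig` left the manager without the parallel fields it reads (`dataloader_num_workers`, `prefetch_factor`). A simplified sketch of that pattern, assuming the usual `InstantiateConfig` behavior from `nerfstudio/configs/base_config.py`:

```python
# Simplified sketch of the config/_target pattern (assumption: setup() hands
# the config to the target class, as in nerfstudio.configs.base_config).
from dataclasses import dataclass
from typing import Any, Type


@dataclass
class InstantiateConfig:
    _target: Type

    def setup(self, **kwargs) -> Any:
        # The target receives the whole config, so the config class must
        # define every field the target expects to read.
        return self._target(self, **kwargs)
```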

nerfstudio/data/datamanagers/full_images_datamanager.py

Lines changed: 10 additions & 4 deletions
```diff
@@ -26,6 +26,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import cached_property
+from itertools import islice
 from pathlib import Path
 from typing import Dict, ForwardRef, Generic, List, Literal, Optional, Tuple, Type, Union, cast, get_args, get_origin
@@ -45,7 +46,7 @@
 from nerfstudio.data.datasets.base_dataset import InputDataset
 from nerfstudio.data.utils.data_utils import identity_collate
 from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image
-from nerfstudio.utils.misc import get_orig_class
+from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
 from nerfstudio.utils.rich_utils import CONSOLE
 
 
@@ -84,7 +85,7 @@ class FullImageDatamanagerConfig(DataManagerConfig):
     dataloader_num_workers: int = 4
     """The number of workers performing the dataloading from either disk/RAM, which
     includes collating, pixel sampling, unprojecting, ray generation etc."""
-    prefetch_factor: int = 4
+    prefetch_factor: Optional[int] = 4
     """The limit number of batches a worker will start loading once an iterator is created.
     More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
     cache_compressed_images: bool = False
```
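The `Optional[int]` change reflects a PyTorch constraint: `torch.utils.data.DataLoader` accepts `prefetch_factor` only in multiprocessing mode and raises a `ValueError` when it is set alongside `num_workers=0`. A minimal sketch of the guard this type change enables (the helper is hypothetical):

```python
# Sketch: with prefetch_factor typed Optional[int], the datamanager can pass
# None when the user runs with 0 workers, which is what DataLoader requires.
from typing import Optional

from torch.utils.data import DataLoader, Dataset


def make_loader(dataset: Dataset, num_workers: int, prefetch_factor: Optional[int]) -> DataLoader:
    return DataLoader(
        dataset,
        num_workers=num_workers,
        # prefetch_factor may only be specified when num_workers > 0
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
    )
```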
```diff
@@ -356,9 +357,9 @@ def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]:
             self.eval_imagebatch_stream,
             batch_size=1,
             num_workers=0,
-            collate_fn=identity_collate,
+            collate_fn=lambda x: x[0],
         )
-        return [batch[0] for batch in dataloader]
+        return list(islice(dataloader, len(self.eval_dataset)))
 
         image_indices = [i for i in range(len(self.eval_dataset))]
         data = [d.copy() for d in self.cached_eval]
```
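The `islice` change is presumably what fixes the hang in `ns-eval` with disk caching: `ImageBatchStream` appears to yield batches indefinitely, so `[batch[0] for batch in dataloader]` would never terminate, whereas `islice` takes exactly `len(self.eval_dataset)` items. A toy illustration of the pattern:

```python
# Taking exactly n items from a potentially endless stream: list(stream) would
# hang on an infinite iterator, islice bounds it.
from itertools import islice


def first_n(stream, n):
    return list(islice(stream, n))


# iter(int, 1) calls int() forever (always 0, never hitting the sentinel 1).
assert first_n(iter(int, 1), 3) == [0, 0, 0]
```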
```diff
@@ -388,6 +389,9 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
         self.train_count += 1
         if self.config.cache_images == "disk":
             camera, data = next(self.iter_train_image_dataloader)[0]
+            camera = camera.to(self.device)
+            data = get_dict_to_torch(data, self.device)
+            print(camera.metadata)
             return camera, data
 
         image_idx = self.train_unseen_cameras.pop(0)
@@ -414,6 +418,8 @@ def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
         self.eval_count += 1
         if self.config.cache_images == "disk":
             camera, data = next(self.iter_eval_image_dataloader)[0]
+            camera = camera.to(self.device)
+            data = get_dict_to_torch(data, self.device)
             return camera, data
 
         return self.next_eval_image(step=step)
```
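Both `next_train` and `next_eval` now perform the host-to-device move in the main process via `nerfstudio.utils.misc.get_dict_to_torch`. A rough stand-in for what that helper does (the real function also supports key exclusion and nested dicts):

```python
# Rough stand-in for get_dict_to_torch: move tensor values of a batch dict to
# the target device, leaving non-tensor entries (strings, ints, ...) untouched.
from typing import Any, Dict, Union

import torch


def dict_to_torch(batch: Dict[str, Any], device: Union[str, torch.device]) -> Dict[str, Any]:
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
```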

nerfstudio/data/datamanagers/parallel_datamanager.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -40,7 +40,7 @@
     RayBatchStream,
     variable_res_collate,
 )
-from nerfstudio.utils.misc import get_orig_class
+from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
 from nerfstudio.utils.rich_utils import CONSOLE
 
 
@@ -56,7 +56,7 @@ class ParallelDataManagerConfig(VanillaDataManagerConfig):
     dataloader_num_workers: int = 4
     """The number of workers performing the dataloading from either disk/RAM, which
     includes collating, pixel sampling, unprojecting, ray generation etc."""
-    prefetch_factor: int = 10
+    prefetch_factor: Optional[int] = 10
     """The limit number of batches a worker will start loading once an iterator is created.
     More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
     cache_compressed_images: bool = False
@@ -241,12 +241,16 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
         """Returns the next batch of data from the train dataloader."""
         self.train_count += 1
         ray_bundle, batch = next(self.iter_train_raybundles)[0]
+        ray_bundle = ray_bundle.to(self.device)
+        batch = get_dict_to_torch(batch, self.device)
         return ray_bundle, batch
 
     def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
         """Returns the next batch of data from the eval dataloader."""
         self.eval_count += 1
         ray_bundle, batch = next(self.iter_train_raybundles)[0]
+        ray_bundle = ray_bundle.to(self.device)
+        batch = get_dict_to_torch(batch, self.device)
         return ray_bundle, batch
 
     def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
```
nerfstudio/data/utils/dataloaders.py

Lines changed: 4 additions & 8 deletions
```diff
@@ -574,18 +574,18 @@ def __iter__(self):
         """
         Here, the variable 'batch' refers to the output of our pixel sampler.
           - batch is a dict_keys(['image', 'indices'])
-          - batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch.
+          - batch['image'] returns a `torch.Size([4096, 3])` tensor on CPU, where 4096 = num_rays_per_batch.
             - Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through.
             - The info of what specific image index that pixel belongs to is stored within batch['indices']
-          - batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_idx, pixelRow, pixelCol)
+          - batch['indices'] returns a `torch.Size([4096, 3])` tensor on CPU where each row represents (image_idx, pixelRow, pixelCol)
         pixel_sampler (for variable_res_collate) will loop through each image, sample pixels within the mask, and return
         them as the variable `indices`, which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol)
         """
         batch = worker_pixel_sampler.sample(collated_batch)  # type: ignore
         # Note: collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device contains 'image'
         ray_indices = batch["indices"]
-        # the ray_bundle is on the GPU; batch["image"] is on the CPU, here we move it to the GPU
-        ray_bundle = self.ray_generator(ray_indices).to(self.device)
+        # Both ray_bundle and batch["image"] are on the CPU and will be moved to the GPU in the main process (parallel_datamanager.py)
+        ray_bundle = self.ray_generator(ray_indices)
         if self.custom_ray_processor:
             ray_bundle, batch = self.custom_ray_processor(ray_bundle, batch)
@@ -645,10 +645,6 @@ def __iter__(self):
             camera, data = self.custom_image_processor(camera, data)
 
         i += 1
-        camera = camera.to(self.device)
-        for k in data.keys():
-            if isinstance(data[k], torch.Tensor):
-                data[k] = data[k].to(self.device)
         yield camera, data
```
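These deleted worker-side `.to(self.device)` calls are the crux of the "workers creating GPU tensors" serialization fix: DataLoader workers hand batches back through multiprocessing queues, and CUDA tensors are fragile across that boundary (CUDA cannot be re-initialized in forked workers, and CUDA IPC sharing is easy to get wrong). The safe pattern, sketched below for a generic multi-worker loader, is CPU-only tensors in the workers and a single `.to(device)` in the main process, which is exactly where the two datamanagers now do it:

```python
# Sketch of the worker/main-process split after this commit: workers yield
# CPU-only batches; the consumer moves them to the GPU once.
import torch
from torch.utils.data import DataLoader, IterableDataset


class CpuOnlyStream(IterableDataset):
    def __iter__(self):
        while True:
            # Produce on CPU; creating CUDA tensors here is what triggered
            # the serialization errors this commit fixes.
            yield {"image": torch.rand(4096, 3)}


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loader = DataLoader(CpuOnlyStream(), batch_size=None, num_workers=2, prefetch_factor=2)
    batch = next(iter(loader))
    batch = {k: v.to(device) for k, v in batch.items()}  # single move, main process
```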
