26 changes: 26 additions & 0 deletions docs/developer_guides/pipelines/datamanagers.md
@@ -115,6 +115,32 @@ To train splatfacto with a large dataset that's unable to fit in memory, please
ns-train splatfacto --data {PROCESSED_DATA_DIR} --pipeline.datamanager.cache-images disk
```

Check out these flowcharts for more ways to customize training on large datasets!

```{image} imgs/DatamanagerGuide-LargeNeRF-light.png
:align: center
:class: only-light
:width: 600
```

```{image} imgs/DatamanagerGuide-LargeNeRF-dark.png
:align: center
:class: only-dark
:width: 600
```

```{image} imgs/DatamanagerGuide-Large3DGS-light.png
:align: center
:class: only-light
:width: 600
```

```{image} imgs/DatamanagerGuide-Large3DGS-dark.png
:align: center
:class: only-dark
:width: 600
```
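
These behaviors are controlled by fields on the datamanager config. As a rough sketch only (the field names come from `FullImageDatamanagerConfig`; the values are illustrative, not tuned recommendations), they can also be set directly in Python:

```python
# Sketch: configuring the full-image datamanager for a large dataset.
# Field names are from FullImageDatamanagerConfig; values are illustrative.
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig

datamanager_config = FullImageDatamanagerConfig(
    cache_images="disk",            # stream images from disk instead of holding them all in memory
    dataloader_num_workers=4,       # background workers that load, collate, and undistort images
    prefetch_factor=4,              # batches each worker loads ahead (passed to torch's DataLoader)
    cache_compressed_images=False,  # left at its default; see the field's docstring above
)
```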

## Migrating Your DataManager to the new DataManager
Many methods subclass a DataManager and add extra data to it. If you would like your custom datamanager to support the new parallel features, you can migrate any custom dataloading logic to the new `custom_ray_processor()` API. This function takes in a full training batch (either an image or a ray bundle) and lets you modify or add to it: it provides an interface to attach new information to the RayBundle (for ray-based methods), the Cameras object (for splatting-based methods), or the ground-truth dictionary. It runs in a background process if disk caching is enabled; otherwise it runs in the main process. A generic sketch is shown below, followed by an example for the LERF method, which was built on Nerfstudio's VanillaDataManager.
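
As a quick, hypothetical sketch (not the actual LERF code; the subclass name and the `"scale"` metadata key are invented for illustration), a ray-based method might override it like this:

```python
# Hypothetical sketch of overriding custom_ray_processor() on a ParallelDataManager
# subclass. The "scale" entry is made up purely for illustration.
from typing import Dict, Tuple

import torch

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager


class MyDataManager(ParallelDataManager):
    def custom_ray_processor(self, ray_bundle: RayBundle, batch: Dict) -> Tuple[RayBundle, Dict]:
        # Runs in a background process when disk caching is enabled, otherwise in the main process.
        if ray_bundle.metadata is None:
            ray_bundle.metadata = {}
        # Attach extra per-ray information to both the bundle and the ground-truth dict.
        ray_bundle.metadata["scale"] = torch.ones_like(ray_bundle.pixel_area)
        batch["scale"] = ray_bundle.metadata["scale"]
        return ray_bundle, batch
```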

(Four image files, presumably the flowchart figures referenced above, were added as binary files and cannot be displayed in the diff view.)
2 changes: 1 addition & 1 deletion nerfstudio/configs/method_configs.py
@@ -219,7 +219,7 @@
max_num_iterations=30000,
mixed_precision=True,
pipeline=VanillaPipelineConfig(
datamanager=VanillaDataManagerConfig(
datamanager=ParallelDataManagerConfig(
_target=ParallelDataManager[DepthDataset],
dataparser=NerfstudioDataParserConfig(),
train_num_rays_per_batch=4096,
14 changes: 10 additions & 4 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -26,6 +26,7 @@
from copy import deepcopy
from dataclasses import dataclass, field
from functools import cached_property
from itertools import islice
from pathlib import Path
from typing import Dict, ForwardRef, Generic, List, Literal, Optional, Tuple, Type, Union, cast, get_args, get_origin

@@ -45,7 +46,7 @@
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.utils.data_utils import identity_collate
from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image
from nerfstudio.utils.misc import get_orig_class
from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE


@@ -84,7 +85,7 @@ class FullImageDatamanagerConfig(DataManagerConfig):
dataloader_num_workers: int = 4
"""The number of workers performing the dataloading from either disk/RAM, which
includes collating, pixel sampling, unprojecting, ray generation etc."""
prefetch_factor: int = 4
prefetch_factor: Optional[int] = 4
"""The limit number of batches a worker will start loading once an iterator is created.
More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
cache_compressed_images: bool = False
@@ -356,9 +357,9 @@ def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]:
self.eval_imagebatch_stream,
batch_size=1,
num_workers=0,
collate_fn=identity_collate,
collate_fn=lambda x: x[0],
)
return [batch[0] for batch in dataloader]
return list(islice(dataloader, len(self.eval_dataset)))

image_indices = [i for i in range(len(self.eval_dataset))]
data = [d.copy() for d in self.cached_eval]
@@ -388,6 +389,9 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
self.train_count += 1
if self.config.cache_images == "disk":
camera, data = next(self.iter_train_image_dataloader)[0]
camera = camera.to(self.device)
data = get_dict_to_torch(data, self.device)
return camera, data

image_idx = self.train_unseen_cameras.pop(0)
@@ -414,6 +418,8 @@ def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
self.eval_count += 1
if self.config.cache_images == "disk":
camera, data = next(self.iter_eval_image_dataloader)[0]
camera = camera.to(self.device)
data = get_dict_to_torch(data, self.device)
return camera, data

return self.next_eval_image(step=step)
8 changes: 6 additions & 2 deletions nerfstudio/data/datamanagers/parallel_datamanager.py
@@ -40,7 +40,7 @@
RayBatchStream,
variable_res_collate,
)
from nerfstudio.utils.misc import get_orig_class
from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE


@@ -56,7 +56,7 @@ class ParallelDataManagerConfig(VanillaDataManagerConfig):
dataloader_num_workers: int = 4
"""The number of workers performing the dataloading from either disk/RAM, which
includes collating, pixel sampling, unprojecting, ray generation etc."""
prefetch_factor: int = 10
prefetch_factor: Optional[int] = 10
"""The limit number of batches a worker will start loading once an iterator is created.
More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
cache_compressed_images: bool = False
@@ -241,12 +241,16 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
"""Returns the next batch of data from the train dataloader."""
self.train_count += 1
ray_bundle, batch = next(self.iter_train_raybundles)[0]
ray_bundle = ray_bundle.to(self.device)
batch = get_dict_to_torch(batch, self.device)
return ray_bundle, batch

def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
"""Returns the next batch of data from the eval dataloader."""
self.eval_count += 1
ray_bundle, batch = next(self.iter_train_raybundles)[0]
ray_bundle = ray_bundle.to(self.device)
batch = get_dict_to_torch(batch, self.device)
return ray_bundle, batch

def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
12 changes: 4 additions & 8 deletions nerfstudio/data/utils/dataloaders.py
@@ -574,18 +574,18 @@ def __iter__(self):
"""
Here, the variable 'batch' refers to the output of our pixel sampler.
- batch is a dict_keys(['image', 'indices'])
- batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch.
- batch['image'] returns a `torch.Size([4096, 3])` tensor on CPU, where 4096 = num_rays_per_batch.
Contributor:
do we support rgba supervision here?

Contributor Author:
Yes, it does! When an RGBA image is present in the dataset, it gets converted to RGB format in the InputDataset.

Specifically this is what happens:

  1. dataloaders.py's RayBatchStream will call self.input_dataset.__getitem__
  2. InputDataset's __getitem__() method calls self.get_data
  3. get_data will call get_image_float32
  4. get_image_float32 has the code for RGBA support

- Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through.
- The image index that each pixel belongs to is stored within batch['indices']
- batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_idx, pixelRow, pixelCol)
- batch['indices'] returns a `torch.Size([4096, 3])` tensor on CPU where each row represents (image_idx, pixelRow, pixelCol)
pixel_sampler (for variable_res_collate) will loop through each image, sample pixels within the mask, and return
them as the variable `indices`, which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol)
"""
batch = worker_pixel_sampler.sample(collated_batch) # type: ignore
# Note: collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device contains 'image'
ray_indices = batch["indices"]
# the ray_bundle is on the GPU; batch["image"] is on the CPU, here we move it to the GPU
ray_bundle = self.ray_generator(ray_indices).to(self.device)
# Both ray_bundle and batch["image"] are on the CPU and will be moved to the GPU in the main process (parallel_datamanager.py)
ray_bundle = self.ray_generator(ray_indices)
if self.custom_ray_processor:
ray_bundle, batch = self.custom_ray_processor(ray_bundle, batch)

@@ -645,10 +645,6 @@ def __iter__(self):
camera, data = self.custom_image_processor(camera, data)

i += 1
camera = camera.to(self.device)
for k in data.keys():
if isinstance(data[k], torch.Tensor):
data[k] = data[k].to(self.device)
yield camera, data

