diff --git a/docs/developer_guides/pipelines/datamanagers.md b/docs/developer_guides/pipelines/datamanagers.md
index 316514c641..0d4a8b8ba1 100644
--- a/docs/developer_guides/pipelines/datamanagers.md
+++ b/docs/developer_guides/pipelines/datamanagers.md
@@ -115,6 +115,32 @@ To train splatfacto with a large dataset that's unable to fit in memory, please
 ns-train splatfacto --data {PROCESSED_DATA_DIR} --pipeline.datamanager.cache-images disk
 ```
 
+Check out these flowcharts for more ways to customize training on large datasets!
+
+```{image} imgs/DatamanagerGuide-LargeNeRF-light.png
+:align: center
+:class: only-light
+:width: 600
+```
+
+```{image} imgs/DatamanagerGuide-LargeNeRF-dark.png
+:align: center
+:class: only-dark
+:width: 600
+```
+
+```{image} imgs/DatamanagerGuide-Large3DGS-light.png
+:align: center
+:class: only-light
+:width: 600
+```
+
+```{image} imgs/DatamanagerGuide-Large3DGS-dark.png
+:align: center
+:class: only-dark
+:width: 600
+```
+
 ## Migrating Your DataManager to the new DataManager
 
 Many methods subclass a DataManager and add extra data to it. If you would like your custom datamanager to also support new parallel features, you can migrate any custom dataloading logic to the new `custom_ray_processor()` API. This function takes in a full training batch (either image or ray bundle) and allows the user to modify or add to it. Let's take a look at an example for the LERF method, which was built on Nerfstudio's VanillaDataManager. This API provides an interface to attach new information to the RayBundle (for ray based methods), Cameras object (for splatting based methods), or ground truth dictionary. It runs in a background process if disk caching is enabled, otherwise it runs in the main process.
diff --git a/docs/developer_guides/pipelines/imgs/DatamanagerGuide-Large3DGS-dark.png b/docs/developer_guides/pipelines/imgs/DatamanagerGuide-Large3DGS-dark.png
new file mode 100644
index 0000000000..cdbb2f8f7a
Binary files /dev/null and b/docs/developer_guides/pipelines/imgs/DatamanagerGuide-Large3DGS-dark.png differ
diff --git a/docs/developer_guides/pipelines/imgs/DatamanagerGuide-Large3DGS-light.png b/docs/developer_guides/pipelines/imgs/DatamanagerGuide-Large3DGS-light.png
new file mode 100644
index 0000000000..972577bba1
Binary files /dev/null and b/docs/developer_guides/pipelines/imgs/DatamanagerGuide-Large3DGS-light.png differ
diff --git a/docs/developer_guides/pipelines/imgs/DatamanagerGuide-LargeNeRF-dark.png b/docs/developer_guides/pipelines/imgs/DatamanagerGuide-LargeNeRF-dark.png
new file mode 100644
index 0000000000..23c93aee7d
Binary files /dev/null and b/docs/developer_guides/pipelines/imgs/DatamanagerGuide-LargeNeRF-dark.png differ
diff --git a/docs/developer_guides/pipelines/imgs/DatamanagerGuide-LargeNeRF-light.png b/docs/developer_guides/pipelines/imgs/DatamanagerGuide-LargeNeRF-light.png
new file mode 100644
index 0000000000..ebdec0c2ac
Binary files /dev/null and b/docs/developer_guides/pipelines/imgs/DatamanagerGuide-LargeNeRF-light.png differ
diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py
index bc1b4225aa..dbb6acf14c 100644
--- a/nerfstudio/configs/method_configs.py
+++ b/nerfstudio/configs/method_configs.py
@@ -219,7 +219,7 @@
     max_num_iterations=30000,
     mixed_precision=True,
     pipeline=VanillaPipelineConfig(
-        datamanager=VanillaDataManagerConfig(
+        datamanager=ParallelDataManagerConfig(
            _target=ParallelDataManager[DepthDataset],
            dataparser=NerfstudioDataParserConfig(),
            train_num_rays_per_batch=4096,
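
The `method_configs.py` hunk above fixes a config/target mismatch: depth-nerfacto already pointed its `_target` at `ParallelDataManager[DepthDataset]` but wrapped it in `VanillaDataManagerConfig`, which lacks the parallel-loading fields. A minimal sketch of the corrected pairing in isolation (values mirror the diff; imports assume the module layout shown elsewhere in this diff):

```python
from nerfstudio.data.datamanagers.parallel_datamanager import (
    ParallelDataManager,
    ParallelDataManagerConfig,
)
from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
from nerfstudio.data.datasets.depth_dataset import DepthDataset

# ParallelDataManagerConfig is the config class that actually carries the
# parallel knobs (dataloader_num_workers, prefetch_factor, ...), so it must
# be the wrapper whenever _target is a ParallelDataManager.
datamanager = ParallelDataManagerConfig(
    _target=ParallelDataManager[DepthDataset],  # specialize the dataset type
    dataparser=NerfstudioDataParserConfig(),
    train_num_rays_per_batch=4096,
)
```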
diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py
index 3ec06120cf..13fb6f68d4 100644
--- a/nerfstudio/data/datamanagers/full_images_datamanager.py
+++ b/nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -26,6 +26,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import cached_property
+from itertools import islice
 from pathlib import Path
 from typing import Dict, ForwardRef, Generic, List, Literal, Optional, Tuple, Type, Union, cast, get_args, get_origin
 
@@ -45,7 +46,7 @@
 from nerfstudio.data.datasets.base_dataset import InputDataset
 from nerfstudio.data.utils.data_utils import identity_collate
 from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image
-from nerfstudio.utils.misc import get_orig_class
+from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
 from nerfstudio.utils.rich_utils import CONSOLE
 
 
@@ -84,7 +85,7 @@ class FullImageDatamanagerConfig(DataManagerConfig):
     dataloader_num_workers: int = 4
     """The number of workers performing the dataloading from either disk/RAM, which includes
     collating, pixel sampling, unprojecting, ray generation etc."""
-    prefetch_factor: int = 4
+    prefetch_factor: Optional[int] = 4
     """The limit number of batches a worker will start loading once an iterator is created.
     More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
     cache_compressed_images: bool = False
@@ -356,9 +357,9 @@ def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]:
                self.eval_imagebatch_stream,
                batch_size=1,
                num_workers=0,
-               collate_fn=identity_collate,
+               collate_fn=lambda x: x[0],
            )
-           return [batch[0] for batch in dataloader]
+           return list(islice(dataloader, len(self.eval_dataset)))
 
        image_indices = [i for i in range(len(self.eval_dataset))]
        data = [d.copy() for d in self.cached_eval]
@@ -388,6 +389,8 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
        self.train_count += 1
        if self.config.cache_images == "disk":
            camera, data = next(self.iter_train_image_dataloader)[0]
+           camera = camera.to(self.device)
+           data = get_dict_to_torch(data, self.device)
            return camera, data
 
        image_idx = self.train_unseen_cameras.pop(0)
@@ -414,6 +417,8 @@ def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
        self.eval_count += 1
        if self.config.cache_images == "disk":
            camera, data = next(self.iter_eval_image_dataloader)[0]
+           camera = camera.to(self.device)
+           data = get_dict_to_torch(data, self.device)
            return camera, data
 
        return self.next_eval_image(step=step)
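
The device transfers deleted from the worker side reappear above in the main process: `camera.to(self.device)` plus `get_dict_to_torch` on the ground-truth dict. For readers unfamiliar with that helper, here is a self-contained sketch of the pattern it implements (an illustrative re-implementation, not nerfstudio's exact code):

```python
from typing import Any

import torch


def dict_to_device(stuff: Any, device: torch.device) -> Any:
    """Recursively move every tensor in a (possibly nested) dict to `device`.

    Illustrative stand-in for the role get_dict_to_torch plays in
    next_train/next_eval above; non-tensor values pass through unchanged.
    """
    if isinstance(stuff, dict):
        return {k: dict_to_device(v, device) for k, v in stuff.items()}
    if isinstance(stuff, torch.Tensor):
        return stuff.to(device)
    return stuff
```

Doing the single transfer here keeps the dataloader workers CPU-only, which avoids initializing CUDA in forked worker processes and avoids shipping GPU tensors through inter-process queues.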
diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py
index fe3a62f3c4..e1e1f6ef52 100644
--- a/nerfstudio/data/datamanagers/parallel_datamanager.py
+++ b/nerfstudio/data/datamanagers/parallel_datamanager.py
@@ -40,7 +40,7 @@
     RayBatchStream,
     variable_res_collate,
 )
-from nerfstudio.utils.misc import get_orig_class
+from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
 from nerfstudio.utils.rich_utils import CONSOLE
 
 
@@ -56,7 +56,7 @@ class ParallelDataManagerConfig(VanillaDataManagerConfig):
     dataloader_num_workers: int = 4
     """The number of workers performing the dataloading from either disk/RAM, which includes
     collating, pixel sampling, unprojecting, ray generation etc."""
-    prefetch_factor: int = 10
+    prefetch_factor: Optional[int] = 10
     """The limit number of batches a worker will start loading once an iterator is created.
     More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
     cache_compressed_images: bool = False
@@ -241,12 +241,16 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
        """Returns the next batch of data from the train dataloader."""
        self.train_count += 1
        ray_bundle, batch = next(self.iter_train_raybundles)[0]
+       ray_bundle = ray_bundle.to(self.device)
+       batch = get_dict_to_torch(batch, self.device)
        return ray_bundle, batch
 
    def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
        """Returns the next batch of data from the eval dataloader."""
        self.eval_count += 1
        ray_bundle, batch = next(self.iter_train_raybundles)[0]
+       ray_bundle = ray_bundle.to(self.device)
+       batch = get_dict_to_torch(batch, self.device)
        return ray_bundle, batch
 
    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
diff --git a/nerfstudio/data/utils/dataloaders.py b/nerfstudio/data/utils/dataloaders.py
index 9fa2faaf7d..f9a6ccc5f1 100644
--- a/nerfstudio/data/utils/dataloaders.py
+++ b/nerfstudio/data/utils/dataloaders.py
@@ -574,18 +574,18 @@ def __iter__(self):
            """
            Here, the variable 'batch' refers to the output of our pixel sampler.
                - batch is a dict_keys(['image', 'indices'])
-               - batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch.
+               - batch['image'] returns a `torch.Size([4096, 3])` tensor on CPU, where 4096 = num_rays_per_batch.
                    - Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through.
                    - The info of what specific image index that pixel belongs to is stored within batch[’indices’]
-               - batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_idx, pixelRow, pixelCol)
+               - batch['indices'] returns a `torch.Size([4096, 3])` tensor on CPU where each row represents (image_idx, pixelRow, pixelCol)
            pixel_sampler (for variable_res_collate) will loop though each image, samples pixel within the mask, and returns them as the
            variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol)
            """
            batch = worker_pixel_sampler.sample(collated_batch)  # type: ignore
            # Note: collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device contains 'image'
            ray_indices = batch["indices"]
-           # the ray_bundle is on the GPU; batch["image"] is on the CPU, here we move it to the GPU
-           ray_bundle = self.ray_generator(ray_indices).to(self.device)
+           # Both ray_bundle and batch["image"] are on the CPU and will be moved to the GPU in the main process (parallel_datamanager.py)
+           ray_bundle = self.ray_generator(ray_indices)
 
            if self.custom_ray_processor:
                ray_bundle, batch = self.custom_ray_processor(ray_bundle, batch)
@@ -645,10 +645,6 @@ def __iter__(self):
                camera, data = self.custom_image_processor(camera, data)
            i += 1
 
-           camera = camera.to(self.device)
-           for k in data.keys():
-               if isinstance(data[k], torch.Tensor):
-                   data[k] = data[k].to(self.device)
            yield camera, data
 
 
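
With the workers now yielding CPU-resident batches, method-specific logic belongs in the `custom_ray_processor()` hook that the docs section above introduces (note the `self.custom_ray_processor(ray_bundle, batch)` call preserved in `RayBatchStream.__iter__`). A minimal sketch of such an override, assuming the hook is overridden on a datamanager subclass as the migration section describes; the class name and the `clip_interpolator` helper are hypothetical stand-ins in the spirit of the LERF example the docs mention:

```python
from typing import Dict, Tuple

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager


class LERFLikeDataManager(ParallelDataManager):
    """Hypothetical subclass illustrating where custom dataloading logic now lives."""

    def custom_ray_processor(self, ray_bundle: RayBundle, batch: Dict) -> Tuple[RayBundle, Dict]:
        # Runs in a background process when disk caching is enabled, so it should
        # touch only CPU tensors and picklable state; per the dataloaders.py change
        # above, ray_bundle is no longer on the GPU at this point.
        # batch["indices"] rows are (image_idx, pixelRow, pixelCol).
        batch["clip"] = self.clip_interpolator(batch["indices"])  # hypothetical helper
        return ray_bundle, batch
```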