
Commit 7431a5b

Merge branch 'main' into main
2 parents 5dbb321 + 5003d0e commit 7431a5b

File tree

10 files changed (+58, -33 lines)


docs/developer_guides/pipelines/datamanagers.md

Lines changed: 26 additions & 0 deletions

@@ -115,6 +115,32 @@ To train splatfacto with a large dataset that's unable to fit in memory, please
 ns-train splatfacto --data {PROCESSED_DATA_DIR} --pipeline.datamanager.cache-images disk
 ```

+Checkout these flowcharts for more customization on large datasets!
+
+```{image} imgs/DatamanagerGuide-LargeNeRF-light.png
+:align: center
+:class: only-light
+:width: 600
+```
+
+```{image} imgs/DatamanagerGuide-LargeNeRF-dark.png
+:align: center
+:class: only-dark
+:width: 600
+```
+
+```{image} imgs/DatamanagerGuide-Large3DGS-light.png
+:align: center
+:class: only-light
+:width: 600
+```
+
+```{image} imgs/DatamanagerGuide-Large3DGS-dark.png
+:align: center
+:class: only-dark
+:width: 600
+```
+
 ## Migrating Your DataManager to the new DataManager
 Many methods subclass a DataManager and add extra data to it. If you would like your custom datamanager to also support new parallel features, you can migrate any custom dataloading logic to the new `custom_ray_processor()` API. This function takes in a full training batch (either image or ray bundle) and allows the user to modify or add to it. Let's take a look at an example for the LERF method, which was built on Nerfstudio's VanillaDataManager. This API provides an interface to attach new information to the RayBundle (for ray based methods), Cameras object (for splatting based methods), or ground truth dictionary. It runs in a background process if disk caching is enabled, otherwise it runs in the main process.
4 binary image files added (the DatamanagerGuide flowchart PNGs referenced above): 526 KB, 528 KB, 532 KB, and 532 KB.
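The `custom_ray_processor()` workflow described in the docs paragraph above can be sketched briefly. This is not the LERF example the docs refer to; it assumes the hook is overridable on `ParallelDataManager` with the `(ray_bundle, batch) -> (ray_bundle, batch)` signature visible in the `dataloaders.py` hunk later in this commit, and `ray_origin_norm` is an invented key.

```python
# Hedged sketch: a custom datamanager that attaches extra per-ray data via
# custom_ray_processor(). The hook name and override style are assumptions
# based on the docs paragraph and the dataloaders.py diff in this commit.
from typing import Dict, Tuple

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManager


class MyParallelDataManager(ParallelDataManager):
    """Adds an illustrative per-ray quantity to every training batch."""

    def custom_ray_processor(self, ray_bundle: RayBundle, batch: Dict) -> Tuple[RayBundle, Dict]:
        # Runs in a background worker when disk caching is enabled,
        # otherwise in the main process (per the docs paragraph above).
        batch["ray_origin_norm"] = ray_bundle.origins.norm(dim=-1, keepdim=True)
        return ray_bundle, batch
```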

nerfstudio/configs/method_configs.py

Lines changed: 1 addition & 1 deletion

@@ -219,7 +219,7 @@
 max_num_iterations=30000,
 mixed_precision=True,
 pipeline=VanillaPipelineConfig(
-    datamanager=VanillaDataManagerConfig(
+    datamanager=ParallelDataManagerConfig(
     _target=ParallelDataManager[DepthDataset],
     dataparser=NerfstudioDataParserConfig(),
     train_num_rays_per_batch=4096,

nerfstudio/data/datamanagers/full_images_datamanager.py

Lines changed: 9 additions & 4 deletions

@@ -26,6 +26,7 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import cached_property
+from itertools import islice
 from pathlib import Path
 from typing import Dict, ForwardRef, Generic, List, Literal, Optional, Tuple, Type, Union, cast, get_args, get_origin

@@ -45,7 +46,7 @@
 from nerfstudio.data.datasets.base_dataset import InputDataset
 from nerfstudio.data.utils.data_utils import identity_collate
 from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image
-from nerfstudio.utils.misc import get_orig_class
+from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
 from nerfstudio.utils.rich_utils import CONSOLE

@@ -84,7 +85,7 @@ class FullImageDatamanagerConfig(DataManagerConfig):
 dataloader_num_workers: int = 4
 """The number of workers performing the dataloading from either disk/RAM, which
 includes collating, pixel sampling, unprojecting, ray generation etc."""
-prefetch_factor: int = 4
+prefetch_factor: Optional[int] = 4
 """The limit number of batches a worker will start loading once an iterator is created.
 More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
 cache_compressed_images: bool = False

@@ -356,9 +357,9 @@ def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]:
 self.eval_imagebatch_stream,
 batch_size=1,
 num_workers=0,
-collate_fn=identity_collate,
+collate_fn=lambda x: x[0],
 )
-return [batch[0] for batch in dataloader]
+return list(islice(dataloader, len(self.eval_dataset)))

 image_indices = [i for i in range(len(self.eval_dataset))]
 data = [d.copy() for d in self.cached_eval]

@@ -388,6 +389,8 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
 self.train_count += 1
 if self.config.cache_images == "disk":
 camera, data = next(self.iter_train_image_dataloader)[0]
+camera = camera.to(self.device)
+data = get_dict_to_torch(data, self.device)
 return camera, data

 image_idx = self.train_unseen_cameras.pop(0)

@@ -414,6 +417,8 @@ def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
 self.eval_count += 1
 if self.config.cache_images == "disk":
 camera, data = next(self.iter_eval_image_dataloader)[0]
+camera = camera.to(self.device)
+data = get_dict_to_torch(data, self.device)
 return camera, data

 return self.next_eval_image(step=step)
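For reference, a hedged sketch of how the dataloading fields touched in this file might be set together. The field names (`cache_images`, `dataloader_num_workers`, `prefetch_factor`) come from the diff above, but the specific values are illustrative, not a recommendation.

```python
# Hedged sketch of the FullImageDatamanagerConfig fields shown in this diff.
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig

datamanager = FullImageDatamanagerConfig(
    cache_images="disk",       # stream images from disk (matches --pipeline.datamanager.cache-images disk)
    dataloader_num_workers=4,  # background workers doing the loading/collating
    prefetch_factor=4,         # batches each worker preloads; now Optional[int] so it can be None when no workers are used
)
```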

nerfstudio/data/datamanagers/parallel_datamanager.py

Lines changed: 6 additions & 2 deletions

@@ -40,7 +40,7 @@
 RayBatchStream,
 variable_res_collate,
 )
-from nerfstudio.utils.misc import get_orig_class
+from nerfstudio.utils.misc import get_dict_to_torch, get_orig_class
 from nerfstudio.utils.rich_utils import CONSOLE

@@ -56,7 +56,7 @@ class ParallelDataManagerConfig(VanillaDataManagerConfig):
 dataloader_num_workers: int = 4
 """The number of workers performing the dataloading from either disk/RAM, which
 includes collating, pixel sampling, unprojecting, ray generation etc."""
-prefetch_factor: int = 10
+prefetch_factor: Optional[int] = 10
 """The limit number of batches a worker will start loading once an iterator is created.
 More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
 cache_compressed_images: bool = False

@@ -241,12 +241,16 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
 """Returns the next batch of data from the train dataloader."""
 self.train_count += 1
 ray_bundle, batch = next(self.iter_train_raybundles)[0]
+ray_bundle = ray_bundle.to(self.device)
+batch = get_dict_to_torch(batch, self.device)
 return ray_bundle, batch

 def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
 """Returns the next batch of data from the eval dataloader."""
 self.eval_count += 1
 ray_bundle, batch = next(self.iter_train_raybundles)[0]
+ray_bundle = ray_bundle.to(self.device)
+batch = get_dict_to_torch(batch, self.device)
 return ray_bundle, batch

 def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:

nerfstudio/data/pixel_samplers.py

Lines changed: 12 additions & 18 deletions

@@ -18,6 +18,7 @@

 import random
 import warnings
+from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import Dict, Optional, Type, Union

@@ -335,8 +336,7 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,

 # only sample within the mask, if the mask is in the batch
 all_indices = []
-all_images = []
-all_depth_images = []
+all_images = defaultdict(list)

 assert num_rays_per_batch % 2 == 0, "num_rays_per_batch must be divisible by 2"
 num_rays_per_image = divide_rays_per_image(num_rays_per_batch, num_images)

@@ -350,10 +350,11 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
 )
 indices[:, 0] = i
 all_indices.append(indices)
-all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
-if "depth_image" in batch:
-    all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]])
+for key, value in batch.items():
+    if key in ["image_idx", "mask"]:
+        continue
+    all_images[key].append(value[i][indices[:, 1], indices[:, 2]])
 else:
 for i, num_rays in enumerate(num_rays_per_image):
 image_height, image_width, _ = batch["image"][i].shape

@@ -363,26 +364,19 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
 indices = self.sample_method(num_rays, 1, image_height, image_width, device=device)
 indices[:, 0] = i
 all_indices.append(indices)
-all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
-if "depth_image" in batch:
-    all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]])
+for key, value in batch.items():
+    if key in ["image_idx", "mask"]:
+        continue
+    all_images[key].append(value[i][indices[:, 1], indices[:, 2]])

 indices = torch.cat(all_indices, dim=0)

-c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
-collated_batch = {
-    key: value[c, y, x]
-    for key, value in batch.items()
-    if key not in ("image_idx", "image", "mask", "depth_image") and value is not None
-}
-
-collated_batch["image"] = torch.cat(all_images, dim=0)
-if "depth_image" in batch:
-    collated_batch["depth_image"] = torch.cat(all_depth_images, dim=0)
+collated_batch = {key: torch.cat(all_images[key], dim=0) for key in all_images}

 assert collated_batch["image"].shape[0] == num_rays_per_batch

 # Needed to correct the random indices to their actual camera idx locations.
+c = indices[..., 0].flatten()
 indices[:, 0] = batch["image_idx"][c]
 collated_batch["indices"] = indices  # with the abs camera indices

nerfstudio/data/utils/dataloaders.py

Lines changed: 4 additions & 8 deletions

@@ -574,18 +574,18 @@ def __iter__(self):
 """
 Here, the variable 'batch' refers to the output of our pixel sampler.
 - batch is a dict_keys(['image', 'indices'])
-- batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch.
+- batch['image'] returns a `torch.Size([4096, 3])` tensor on CPU, where 4096 = num_rays_per_batch.
 - Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through.
 - The info of what specific image index that pixel belongs to is stored within batch['indices']
-- batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_idx, pixelRow, pixelCol)
+- batch['indices'] returns a `torch.Size([4096, 3])` tensor on CPU where each row represents (image_idx, pixelRow, pixelCol)
 pixel_sampler (for variable_res_collate) will loop though each image, samples pixel within the mask, and returns
 them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol)
 """
 batch = worker_pixel_sampler.sample(collated_batch)  # type: ignore
 # Note: collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device contains 'image'
 ray_indices = batch["indices"]
-# the ray_bundle is on the GPU; batch["image"] is on the CPU, here we move it to the GPU
-ray_bundle = self.ray_generator(ray_indices).to(self.device)
+# Both ray_bundle and batch["image"] are on the CPU and will be moved to the GPU in the main process (parallel_datamanager.py)
+ray_bundle = self.ray_generator(ray_indices)
 if self.custom_ray_processor:
 ray_bundle, batch = self.custom_ray_processor(ray_bundle, batch)

@@ -645,10 +645,6 @@ def __iter__(self):
 camera, data = self.custom_image_processor(camera, data)

 i += 1
-camera = camera.to(self.device)
-for k in data.keys():
-    if isinstance(data[k], torch.Tensor):
-        data[k] = data[k].to(self.device)
 yield camera, data
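With the worker-side `.to(self.device)` calls removed above, batches now leave the background workers on CPU and are moved to the training device once per step in the main process (the `next_train` / `next_eval` changes in the datamanager files earlier in this commit). A hedged sketch of that consumer-side transfer follows; `move_dict_to_device` is a simplified stand-in for `nerfstudio.utils.misc.get_dict_to_torch`, not its actual implementation.

```python
# Hedged sketch of the main-process device transfer that replaces the worker-side moves.
from typing import Dict

import torch


def move_dict_to_device(data: Dict, device: torch.device) -> Dict:
    """Move every tensor value of a (possibly nested) batch dict to `device`."""
    out = {}
    for key, value in data.items():
        if isinstance(value, torch.Tensor):
            out[key] = value.to(device)
        elif isinstance(value, dict):
            out[key] = move_dict_to_device(value, device)
        else:
            out[key] = value
    return out


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {"image": torch.rand(4096, 3), "indices": torch.zeros(4096, 3, dtype=torch.long)}
batch = move_dict_to_device(batch, device)  # one transfer per training step, in the main process
```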
