Skip to content

Commit 4091694

Browse files
cleanup, passing final ns-dev-test
1 parent 3a82351 commit 4091694

File tree

2 files changed

+25
-50
lines changed

2 files changed

+25
-50
lines changed

nerfstudio/data/datamanagers/base_datamanager.py

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from __future__ import annotations
2020

2121
from abc import abstractmethod
22-
from collections import defaultdict
2322
from dataclasses import dataclass, field
2423
from functools import cached_property
2524
from pathlib import Path
@@ -55,44 +54,19 @@
5554
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
5655
from nerfstudio.data.datasets.base_dataset import InputDataset
5756
from nerfstudio.data.pixel_samplers import PatchPixelSamplerConfig, PixelSampler, PixelSamplerConfig
58-
from nerfstudio.data.utils.dataloaders import CacheDataloader, FixedIndicesEvalDataloader, RandIndicesEvalDataloader
57+
from nerfstudio.data.utils.dataloaders import (
58+
CacheDataloader,
59+
FixedIndicesEvalDataloader,
60+
RandIndicesEvalDataloader,
61+
variable_res_collate,
62+
)
5963
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
6064
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
6165
from nerfstudio.model_components.ray_generators import RayGenerator
6266
from nerfstudio.utils.misc import IterableWrapper, get_orig_class
6367
from nerfstudio.utils.rich_utils import CONSOLE
6468

6569

66-
def variable_res_collate(batch: List[Dict]) -> Dict:
67-
"""Default collate function for the cached dataloader.
68-
Args:
69-
batch: Batch of samples from the dataset.
70-
Returns:
71-
Collated batch.
72-
"""
73-
images = []
74-
imgdata_lists = defaultdict(list)
75-
for data in batch:
76-
image = data.pop("image")
77-
images.append(image)
78-
topop = []
79-
for key, val in data.items():
80-
if isinstance(val, torch.Tensor):
81-
# if the value has same height and width as the image, assume that it should be collated accordingly.
82-
if len(val.shape) >= 2 and val.shape[:2] == image.shape[:2]:
83-
imgdata_lists[key].append(val)
84-
topop.append(key)
85-
# now that iteration is complete, the image data items can be removed from the batch
86-
for key in topop:
87-
del data[key]
88-
89-
new_batch = nerfstudio_collate(batch)
90-
new_batch["image"] = images
91-
new_batch.update(imgdata_lists)
92-
93-
return new_batch
94-
95-
9670
@dataclass
9771
class DataManagerConfig(InstantiateConfig):
9872
"""Configuration for data manager instantiation; DataManager is in charge of keeping the train/eval dataparsers;
@@ -305,8 +279,6 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
305279

306280
@dataclass
307281
class VanillaDataManagerConfig(DataManagerConfig):
308-
"""A basic data manager for a ray-based model"""
309-
310282
_target: Type = field(default_factory=lambda: VanillaDataManager)
311283
"""Target class to instantiate."""
312284
dataparser: AnnotatedDataParserUnion = field(default_factory=BlenderDataParserConfig)

nerfstudio/data/utils/dataloaders.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -524,14 +524,16 @@ def __iter__(self):
524524
"""This implementation allows every worker only cache the indices of the images they will use to generate rays to conserve RAM memory."""
525525
worker_info = get_worker_info()
526526
if worker_info is not None: # if we have multiple processes
527-
per_worker = int(math.ceil(len(self.input_dataset) / float(worker_info.num_workers)))
528-
slice_start = worker_info.id * per_worker
529-
else: # we only have a single process
530-
per_worker = len(self.input_dataset)
531-
slice_start = 0
532-
dataset_indices = list(range(len(self.input_dataset)))
533-
# the indices of the datapoints in the dataset this worker will load
534-
worker_indices = dataset_indices[slice_start : slice_start + per_worker]
527+
if len(self.input_dataset) < worker_info.num_workers:
528+
# if there's fewer datapoints than workers, each worker receives all datapoints
529+
worker_indices = list(range(len(self.input_dataset)))
530+
else:
531+
per_worker = int(math.ceil(len(self.input_dataset) / float(worker_info.num_workers)))
532+
slice_start = worker_info.id * per_worker
533+
dataset_indices = list(range(len(self.input_dataset)))
534+
worker_indices = dataset_indices[slice_start : slice_start + per_worker]
535+
else: # if we only have a single process
536+
worker_indices = list(range(len(self.input_dataset)))
535537
if not self.load_from_disk:
536538
self._cached_collated_batch = self._get_collated_batch(worker_indices)
537539
r = random.Random(3301)
@@ -549,7 +551,6 @@ def __iter__(self):
549551
collated_batch = self._cached_collated_batch
550552
elif i % self.num_times_to_repeat_images == 0:
551553
r.shuffle(worker_indices)
552-
553554
if self.num_images_to_sample_from == -1:
554555
# if -1, the worker gets all available indices in its partition
555556
image_indices = worker_indices
@@ -562,10 +563,12 @@ def __iter__(self):
562563
"""
563564
Here, the variable 'batch' refers to the output of our pixel sampler.
564565
- batch is a dict_keys(['image', 'indices'])
565-
- batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch. Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through. The info of what specific image index that pixel belongs to is stored within batch[’indices’]
566-
- batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_index=camera_index, pixelRow, pixelCol)
567-
What the pixel_sampler does (for variable_res_collate) is that it loops though each image, samples pixel within the mask,
568-
and returns them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol)
566+
- batch['image'] returns a pytorch tensor with shape `torch.Size([4096, 3])` , where 4096 = num_rays_per_batch.
567+
- Note: each row in this tensor represents the RGB values as floats in [0, 1] of the pixel the ray goes through.
568+
- The info of what specific image index that pixel belongs to is stored within batch[’indices’]
569+
- batch['indices'] returns a pytorch tensor `torch.Size([4096, 3])` tensor where each row represents (image_idx, pixelRow, pixelCol)
570+
pixel_sampler (for variable_res_collate) will loop though each image, samples pixel within the mask, and returns
571+
them as the variable `indices` which has shape torch.Size([4096, 3]), where each row represents a pixel (image_idx, pixelRow, pixelCol)
569572
"""
570573
batch = worker_pixel_sampler.sample(collated_batch) # type: ignore
571574
# collated_batch["image"].get_device() will return CPU if self.exclude_batch_keys_from_device contains 'image'
@@ -632,9 +635,9 @@ def __iter__(self):
632635

633636
i += 1
634637
camera = camera.to(self.device)
635-
for k in data.keys():
636-
if isinstance(data[k], torch.Tensor):
637-
data[k] = data[k].to(self.device)
638+
# for k in data.keys():
639+
# if isinstance(data[k], torch.Tensor):
640+
# data[k] = data[k].to(self.device)
638641
yield camera, data
639642

640643

0 commit comments

Comments
 (0)