Commit fcee08d

Fix prefetch_factor type issues when the user decides to use 0 workers
1 parent 217aee7 commit fcee08d

3 files changed: +240 -2 lines

nerfstudio/data/datamanagers/full_images_datamanager.py

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ class FullImageDatamanagerConfig(DataManagerConfig):
     dataloader_num_workers: int = 4
     """The number of workers performing the dataloading from either disk/RAM, which
     includes collating, pixel sampling, unprojecting, ray generation etc."""
-    prefetch_factor: int = 4
+    prefetch_factor: int | None = 4
     """The limit number of batches a worker will start loading once an iterator is created.
     More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
     cache_compressed_images: bool = False

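Why the annotation must admit None: in recent PyTorch releases, DataLoader accepts a
prefetch_factor only when worker processes are in use; with num_workers=0 it must be
left as None, otherwise construction raises a ValueError. A minimal sketch of the
failure mode this commit guards against (the toy dataset is hypothetical; only the
DataLoader behavior matters):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        """Hypothetical stand-in for an image dataset."""

        def __len__(self) -> int:
            return 4

        def __getitem__(self, idx: int) -> torch.Tensor:
            return torch.tensor(idx)

    # Fine: prefetching only applies when worker processes are used.
    DataLoader(ToyDataset(), num_workers=2, prefetch_factor=4)

    # With 0 workers, a non-None prefetch_factor raises a ValueError.
    try:
        DataLoader(ToyDataset(), num_workers=0, prefetch_factor=4)
    except ValueError as err:
        print(err)
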
nerfstudio/data/datamanagers/parallel_datamanager.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ class ParallelDataManagerConfig(VanillaDataManagerConfig):
     dataloader_num_workers: int = 4
     """The number of workers performing the dataloading from either disk/RAM, which
     includes collating, pixel sampling, unprojecting, ray generation etc."""
-    prefetch_factor: int = 10
+    prefetch_factor: int | None = 10
     """The limit number of batches a worker will start loading once an iterator is created.
     More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
     cache_compressed_images: bool = False
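
With the field typed int | None, a call site can forward the configured value
regardless of the worker count. The dataloader construction sites are not part of
this diff, so the guard below is only an illustrative pattern (the helper name and
arguments are made up), not the project's actual code:

    from __future__ import annotations

    from torch.utils.data import DataLoader

    def make_image_loader(dataset, num_workers: int, prefetch_factor: int | None) -> DataLoader:
        # Forward prefetch_factor only when workers exist; otherwise force None
        # so single-process loading does not raise a ValueError.
        return DataLoader(
            dataset,
            batch_size=1,
            num_workers=num_workers,
            prefetch_factor=prefetch_factor if num_workers > 0 else None,
        )
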
nerfstudio/data/datamanagers/parallel_full_images_datamanager.py (new file)

Lines changed: 238 additions & 0 deletions

@@ -0,0 +1,238 @@
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Parallel data manager that outputs cameras / images instead of raybundles.
"""

from __future__ import annotations

import random
from functools import cached_property
from pathlib import Path
from typing import Dict, ForwardRef, Generic, List, Literal, Tuple, Type, Union, cast, get_args, get_origin

import fpsample
import numpy as np
import torch
from torch.nn import Parameter
from torch.utils.data import DataLoader

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.datamanagers.base_datamanager import DataManager, TDataset
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.utils.data_utils import identity_collate
from nerfstudio.data.utils.dataloaders import ImageBatchStream, undistort_view
from nerfstudio.utils.misc import get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE


class ParallelFullImageDatamanager(DataManager, Generic[TDataset]):
    def __init__(
        self,
        config: FullImageDatamanagerConfig,
        device: Union[torch.device, str] = "cpu",
        test_mode: Literal["test", "val", "inference"] = "val",
        world_size: int = 1,
        local_rank: int = 0,
        **kwargs,
    ):
        self.config = config
        self.device = device
        self.world_size = world_size
        self.local_rank = local_rank
        self.sampler = None
        self.test_mode = test_mode
        self.test_split = "test" if test_mode in ["test", "inference"] else "val"
        self.dataparser_config = self.config.dataparser
        if self.config.data is not None:
            self.config.dataparser.data = Path(self.config.data)
        else:
            self.config.data = self.config.dataparser.data
        self.dataparser = self.dataparser_config.setup()
        if test_mode == "inference":
            self.dataparser.downscale_factor = 1  # Avoid opening images
        self.includes_time = self.dataparser.includes_time

        self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train")
        self.train_dataset = self.create_train_dataset()
        self.eval_dataset = self.create_eval_dataset()

        if len(self.train_dataset) > 500 and self.config.cache_images == "gpu":
            CONSOLE.print(
                "Train dataset has over 500 images, overriding cache_images to cpu",
                style="bold yellow",
            )
            self.config.cache_images = "cpu"

        # Some logic to make sure we sample every camera in equal amounts
        self.train_unseen_cameras = self.sample_train_cameras()
        self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))]
        assert len(self.train_unseen_cameras) > 0, "No data found in dataset"

        super().__init__()

    def sample_train_cameras(self):
        """Return a list of camera indices sampled using the strategy specified by
        self.config.train_cameras_sampling_strategy"""
        num_train_cameras = len(self.train_dataset)
        if self.config.train_cameras_sampling_strategy == "random":
            if not hasattr(self, "random_generator"):
                self.random_generator = random.Random(self.config.train_cameras_sampling_seed)
            indices = list(range(num_train_cameras))
            self.random_generator.shuffle(indices)
            return indices
        elif self.config.train_cameras_sampling_strategy == "fps":
            if not hasattr(self, "train_unsampled_epoch_count"):
                np.random.seed(self.config.train_cameras_sampling_seed)  # fix random seed of fpsample
                self.train_unsampled_epoch_count = np.zeros(num_train_cameras)
            camera_origins = self.train_dataset.cameras.camera_to_worlds[..., 3].numpy()
            # We concatenate camera origins with the weighted train_unsampled_epoch_count because we want to
            # increase the chance of sampling cameras that haven't been sampled in previous consecutive epochs.
            # We assume the camera origins are also rescaled, so the weight 0.1 is relative to the scale of the scene.
            data = np.concatenate(
                (camera_origins, 0.1 * np.expand_dims(self.train_unsampled_epoch_count, axis=-1)), axis=-1
            )
            n = self.config.fps_reset_every
            if num_train_cameras < n:
                CONSOLE.log(
                    f"num_train_cameras={num_train_cameras} is smaller than fps_reset_every={n}; the behavior of the "
                    "camera sampler will be very similar to random sampling without replacement (the default setting)."
                )
                n = num_train_cameras
            kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(data, n, h=3)

            self.train_unsampled_epoch_count += 1
            self.train_unsampled_epoch_count[kdline_fps_samples_idx] = 0
            return kdline_fps_samples_idx.tolist()
        else:
            raise ValueError(f"Unknown train camera sampling strategy: {self.config.train_cameras_sampling_strategy}")

    def create_train_dataset(self) -> TDataset:
        """Sets up the data loaders for training"""
        return self.dataset_type(
            dataparser_outputs=self.train_dataparser_outputs,
            scale_factor=self.config.camera_res_scale_factor,
            cache_compressed_images=self.config.cache_compressed_images,
        )

    def create_eval_dataset(self) -> TDataset:
        """Sets up the data loaders for evaluation"""
        return self.dataset_type(
            dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split),
            scale_factor=self.config.camera_res_scale_factor,
            cache_compressed_images=self.config.cache_compressed_images,
        )

    @cached_property
    def dataset_type(self) -> Type[TDataset]:
        """Returns the dataset type passed as the generic argument"""
        default: Type[TDataset] = cast(TDataset, TDataset.__default__)  # type: ignore
        orig_class: Type[ParallelFullImageDatamanager] = get_orig_class(self, default=None)  # type: ignore
        if type(self) is ParallelFullImageDatamanager and orig_class is None:
            return default
        if orig_class is not None and get_origin(orig_class) is ParallelFullImageDatamanager:
            return get_args(orig_class)[0]

        # For inherited classes, we need to find the correct type to instantiate
        for base in getattr(self, "__orig_bases__", []):
            if get_origin(base) is ParallelFullImageDatamanager:
                for value in get_args(base):
                    if isinstance(value, ForwardRef):
                        if value.__forward_evaluated__:
                            value = value.__forward_value__
                        elif value.__forward_module__ is None:
                            value.__forward_module__ = type(self).__module__
                            value = getattr(value, "_evaluate")(None, None, set())
                    assert isinstance(value, type)
                    if issubclass(value, InputDataset):
                        return cast(Type[TDataset], value)
        return default

    def get_datapath(self) -> Path:
        return self.config.dataparser.data

    def setup_train(self):
        self.train_imagebatch_stream = ImageBatchStream(
            input_dataset=self.train_dataset,
            cache_images_type=self.config.cache_images_type,
            sampling_seed=self.config.train_cameras_sampling_seed,
            device=self.device,
            custom_view_processor=self.custom_view_processor,
        )
        self.train_image_dataloader = DataLoader(
            self.train_imagebatch_stream,
            batch_size=1,
            num_workers=self.config.dataloader_num_workers,
            collate_fn=identity_collate,
            # pin_memory_device=self.device,  # for some reason, if we pin memory, exporting to a PLY file doesn't work
        )
        self.iter_train_image_dataloader = iter(self.train_image_dataloader)

    def setup_eval(self):
        self.eval_imagebatch_stream = ImageBatchStream(
            input_dataset=self.eval_dataset,
            cache_images_type=self.config.cache_images_type,
            sampling_seed=self.config.train_cameras_sampling_seed,
            device=self.device,
            custom_view_processor=self.custom_view_processor,
        )
        self.eval_image_dataloader = DataLoader(
            self.eval_imagebatch_stream,
            batch_size=1,
            num_workers=0,
            collate_fn=identity_collate,
        )
        self.iter_eval_image_dataloader = iter(self.eval_image_dataloader)

    @property
    def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]:
        return self.iter_eval_image_dataloader

    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Get the param groups for the data manager.

        Returns:
            A dictionary of the data manager's param groups.
        """
        return {}

    def get_train_rays_per_batch(self):
        # TODO: fix this to be the resolution of the last image rendered
        return 800 * 800

    def next_train(self, step: int) -> Tuple[Cameras, Dict]:
        self.train_count += 1
        camera, data = next(self.iter_train_image_dataloader)[0]
        return camera, data

    def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
        self.eval_count += 1
        camera, data = next(self.iter_eval_image_dataloader)[0]
        return camera, data

    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
        """Returns the next evaluation batch

        Returns a Camera instead of a RayBundle"""
        image_idx = self.eval_unseen_cameras.pop(random.randint(0, len(self.eval_unseen_cameras) - 1))
        # Make sure to re-populate the unseen cameras list if we have exhausted it
        if len(self.eval_unseen_cameras) == 0:
            self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))]
        return undistort_view(image_idx, self.eval_dataset, self.config.cache_images_type)

    def custom_view_processor(self, camera: Cameras, data: Dict) -> Tuple[Cameras, Dict]:
        """An API hook for adding latents, metadata, or other further customization to the
        camera-and-image view dataloading process, which is parallelized."""
        return camera, data

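A rough usage sketch of the new datamanager under the zero-worker configuration this
commit targets. It assumes the config's defaults supply a dataparser and data paths;
only the fields shown in this diff are certain:

    from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig

    # Single-process loading: prefetch_factor must stay None when
    # dataloader_num_workers is 0, which the int | None annotation now allows.
    config = FullImageDatamanagerConfig(
        dataloader_num_workers=0,
        prefetch_factor=None,
    )
    datamanager = ParallelFullImageDatamanager(config=config, device="cpu")
    datamanager.setup_train()
    camera, data = datamanager.next_train(step=0)  # one full image and its camera
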
0 commit comments