# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Parallel data manager that outputs cameras / images instead of raybundles.
"""

from __future__ import annotations

import random
from functools import cached_property
from pathlib import Path
from typing import Dict, ForwardRef, Generic, List, Literal, Tuple, Type, Union, cast, get_args, get_origin

import fpsample
import numpy as np
import torch
from torch.nn import Parameter
from torch.utils.data import DataLoader

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.datamanagers.base_datamanager import DataManager, TDataset
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.utils.data_utils import identity_collate
from nerfstudio.data.utils.dataloaders import ImageBatchStream, undistort_view
from nerfstudio.utils.misc import get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE


class ParallelFullImageDatamanager(DataManager, Generic[TDataset]):
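    """Full-image data manager that streams whole (camera, image) views to the
    training loop from parallel dataloader workers instead of yielding raybundles."""
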
    def __init__(
        self,
        config: FullImageDatamanagerConfig,
        device: Union[torch.device, str] = "cpu",
        test_mode: Literal["test", "val", "inference"] = "val",
        world_size: int = 1,
        local_rank: int = 0,
        **kwargs,
    ):
        self.config = config
        self.device = device
        self.world_size = world_size
        self.local_rank = local_rank
        self.sampler = None
        self.test_mode = test_mode
        self.test_split = "test" if test_mode in ["test", "inference"] else "val"
        self.dataparser_config = self.config.dataparser
        if self.config.data is not None:
            self.config.dataparser.data = Path(self.config.data)
        else:
            self.config.data = self.config.dataparser.data
        self.dataparser = self.dataparser_config.setup()
        if test_mode == "inference":
            self.dataparser.downscale_factor = 1  # Avoid opening images
        self.includes_time = self.dataparser.includes_time

        self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train")
        self.train_dataset = self.create_train_dataset()
        self.eval_dataset = self.create_eval_dataset()

        if len(self.train_dataset) > 500 and self.config.cache_images == "gpu":
            CONSOLE.print(
                "Train dataset has over 500 images, overriding cache_images to cpu",
                style="bold yellow",
            )
            self.config.cache_images = "cpu"

        # Some logic to make sure we sample every camera in equal amounts
        self.train_unseen_cameras = self.sample_train_cameras()
        self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))]
        assert len(self.train_unseen_cameras) > 0, "No data found in dataset"

        super().__init__()
    def sample_train_cameras(self):
        """Return a list of camera indices sampled using the strategy specified by
        self.config.train_cameras_sampling_strategy"""
        num_train_cameras = len(self.train_dataset)
        if self.config.train_cameras_sampling_strategy == "random":
            if not hasattr(self, "random_generator"):
                self.random_generator = random.Random(self.config.train_cameras_sampling_seed)
            indices = list(range(num_train_cameras))
            self.random_generator.shuffle(indices)
            return indices
        elif self.config.train_cameras_sampling_strategy == "fps":
            if not hasattr(self, "train_unsampled_epoch_count"):
                np.random.seed(self.config.train_cameras_sampling_seed)  # fix the random seed of fpsample
                self.train_unsampled_epoch_count = np.zeros(num_train_cameras)
            camera_origins = self.train_dataset.cameras.camera_to_worlds[..., 3].numpy()
            # We concatenate camera origins with the weighted train_unsampled_epoch_count because we want to
            # increase the chance of sampling cameras that haven't been sampled in previous consecutive epochs.
            # We assume the camera origins are also rescaled, so the weight 0.1 is relative to the scale of the scene.
            data = np.concatenate(
                (camera_origins, 0.1 * np.expand_dims(self.train_unsampled_epoch_count, axis=-1)), axis=-1
            )
            n = self.config.fps_reset_every
            if num_train_cameras < n:
                CONSOLE.log(
                    f"num_train_cameras={num_train_cameras} is smaller than fps_reset_every={n}; the camera "
                    "sampler will behave very similarly to random sampling without replacement (the default setting)."
                )
                n = num_train_cameras
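            # Bucket-based farthest-point sampling over the concatenated
            # (origin, staleness) features; per fpsample's docs, h sets the
            # kd-tree height used for bucketing.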
            kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(data, n, h=3)

            self.train_unsampled_epoch_count += 1
            self.train_unsampled_epoch_count[kdline_fps_samples_idx] = 0
            return kdline_fps_samples_idx.tolist()
        else:
            raise ValueError(f"Unknown train camera sampling strategy: {self.config.train_cameras_sampling_strategy}")

    def create_train_dataset(self) -> TDataset:
        """Sets up the data loaders for training"""
        return self.dataset_type(
            dataparser_outputs=self.train_dataparser_outputs,
            scale_factor=self.config.camera_res_scale_factor,
            cache_compressed_images=self.config.cache_compressed_images,
        )

    def create_eval_dataset(self) -> TDataset:
        """Sets up the data loaders for evaluation"""
        return self.dataset_type(
            dataparser_outputs=self.dataparser.get_dataparser_outputs(split=self.test_split),
            scale_factor=self.config.camera_res_scale_factor,
            cache_compressed_images=self.config.cache_compressed_images,
        )

    @cached_property
    def dataset_type(self) -> Type[TDataset]:
        """Returns the dataset type passed as the generic argument"""
        default: Type[TDataset] = cast(TDataset, TDataset.__default__)  # type: ignore
        orig_class: Type[ParallelFullImageDatamanager] = get_orig_class(self, default=None)  # type: ignore
        if type(self) is ParallelFullImageDatamanager and orig_class is None:
            return default
        if orig_class is not None and get_origin(orig_class) is ParallelFullImageDatamanager:
            return get_args(orig_class)[0]

        # For inherited classes, we need to find the correct type to instantiate
        for base in getattr(self, "__orig_bases__", []):
            if get_origin(base) is ParallelFullImageDatamanager:
                for value in get_args(base):
                    if isinstance(value, ForwardRef):
                        if value.__forward_evaluated__:
                            value = value.__forward_value__
                        elif value.__forward_module__ is None:
                            value.__forward_module__ = type(self).__module__
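                            # ForwardRef._evaluate is a private typing API; it resolves the
                            # string annotation against the owning module's namespace.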
                            value = getattr(value, "_evaluate")(None, None, set())
                        assert isinstance(value, type)
                        if issubclass(value, InputDataset):
                            return cast(Type[TDataset], value)
        return default

    def get_datapath(self) -> Path:
        return self.config.dataparser.data

    def setup_train(self):
        self.train_imagebatch_stream = ImageBatchStream(
            input_dataset=self.train_dataset,
            cache_images_type=self.config.cache_images_type,
            sampling_seed=self.config.train_cameras_sampling_seed,
            device=self.device,
            custom_view_processor=self.custom_view_processor,
        )
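        # Each "batch" is a single (camera, data) view: batch_size=1 combined with
        # identity_collate hands the one-element list through without any tensor stacking.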
        self.train_image_dataloader = DataLoader(
            self.train_imagebatch_stream,
            batch_size=1,
            num_workers=self.config.dataloader_num_workers,
            collate_fn=identity_collate,
            # pin_memory_device=self.device,  # for some reason if we pin memory, exporting to PLY file doesn't work
        )
        self.iter_train_image_dataloader = iter(self.train_image_dataloader)

    def setup_eval(self):
        self.eval_imagebatch_stream = ImageBatchStream(
            input_dataset=self.eval_dataset,
            cache_images_type=self.config.cache_images_type,
            sampling_seed=self.config.train_cameras_sampling_seed,
            device=self.device,
            custom_view_processor=self.custom_view_processor,
        )
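        # num_workers=0 keeps eval loading on the main process; evaluation pulls
        # one view at a time, so extra worker processes would buy little here.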
        self.eval_image_dataloader = DataLoader(
            self.eval_imagebatch_stream,
            batch_size=1,
            num_workers=0,
            collate_fn=identity_collate,
        )
        self.iter_eval_image_dataloader = iter(self.eval_image_dataloader)

    @property
    def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]:
        return self.iter_eval_image_dataloader

    def get_param_groups(self) -> Dict[str, List[Parameter]]:
        """Get the param groups for the data manager.

        Returns:
            A dictionary of lists of parameters keyed by param group name.
        """
        return {}

    def get_train_rays_per_batch(self):
        # TODO: fix this to be the resolution of the last image rendered
        return 800 * 800

    def next_train(self, step: int) -> Tuple[Cameras, Dict]:
        self.train_count += 1
        # identity_collate yields a one-element list, so unpack the single view
        camera, data = next(self.iter_train_image_dataloader)[0]
        return camera, data

    def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
        self.eval_count += 1
        # Draw from the eval dataloader, not the train one
        camera, data = next(self.iter_eval_image_dataloader)[0]
        return camera, data

    def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
        """Returns the next evaluation batch.

        Returns a Camera instead of a raybundle."""
        image_idx = self.eval_unseen_cameras.pop(random.randint(0, len(self.eval_unseen_cameras) - 1))
        # Make sure to re-populate the unseen cameras list if we have exhausted it
        if len(self.eval_unseen_cameras) == 0:
            self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))]
        return undistort_view(image_idx, self.eval_dataset, self.config.cache_images_type)
| 235 | + |
| 236 | + def custom_view_processor(self, camera: Cameras, data: Dict) -> Tuple[Cameras, Dict]: |
| 237 | + """An API to add latents, metadata, or other further customization an camera-and-image view dataloading process that is parallelized""" |
| 238 | + return camera, data |
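

# A minimal sketch of overriding custom_view_processor (the subclass name and the
# "my_latent" metadata key below are hypothetical, for illustration only):
#
#     class MyDatamanager(ParallelFullImageDatamanager[InputDataset]):
#         def custom_view_processor(self, camera: Cameras, data: Dict) -> Tuple[Cameras, Dict]:
#             data["my_latent"] = torch.zeros(16)  # attach per-view metadata in the worker
#             return camera, data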