Commit e4a7661

Add developer documentation for users who want to migrate their custom datamanagers to support the new features
1 parent 3fafbc2 commit e4a7661

1 file changed: +103 −0

docs/developer_guides/pipelines/datamanagers.md

@@ -94,3 +94,106 @@ See the code!
## Creating Your Own

We currently don't have other implementations because most papers follow the VanillaDataManager implementation. However, it should be straightforward to add, for instance, a VanillaDataManager variant that progressively adds cameras by conditioning on the training step and adjusting the RayBundle and RayGT generation logic; a minimal sketch of this idea follows.
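
As a concrete illustration, here is a minimal sketch of that idea. Everything in it is hypothetical: `ProgressiveDataManager`, the unlock schedule, and the assumption that every entry in the collated image batch can be sliced per image are illustrative, not part of Nerfstudio.

```python
from typing import Dict, Tuple

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager


class ProgressiveDataManager(VanillaDataManager):
    """Hypothetical datamanager that starts with a few cameras and unlocks more as training progresses."""

    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
        self.train_count += 1
        image_batch = next(self.iter_train_image_dataloader)
        # Illustrative schedule: start with 5 cameras, unlock one more every 500 steps.
        num_unlocked = min(5 + step // 500, len(image_batch["image"]))
        # Restrict pixel sampling to the unlocked cameras; assumes each batch entry
        # is indexable along the per-image dimension.
        unlocked_batch = {key: value[:num_unlocked] for key, value in image_batch.items()}
        assert self.train_pixel_sampler is not None
        batch = self.train_pixel_sampler.sample(unlocked_batch)
        ray_bundle = self.train_ray_generator(batch["indices"])
        return ray_bundle, batch
```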

## Migrating Your Datamanager to the New Datamanager

As of January 2025, the FullImageDatamanager and ParallelImageDatamanager implementations support parallelized dataloading and dataloading from disk to conserve CPU RAM. If you would like your custom datamanager to also support these new features, you can migrate any custom dataloading logic to the new processor hooks: `custom_ray_processor` for ray-based datamanagers and `custom_view_processor` for image-based ones. Let's take a look at an example for the LERF method, which was built on Nerfstudio's VanillaDataManager.

```python
class LERFDataManager(VanillaDataManager):  # pylint: disable=abstract-method
    """Basic stored data manager implementation.

    This is pretty much a port over from our old dataloading utilities, and is a little jank
    under the hood. We may clean this up a little bit under the hood with more standard dataloading
    components that can be strung together, but it can be just used as a black box for now since
    only the constructor is likely to change in the future, or maybe passing in step number to the
    next_train and next_eval functions.

    Args:
        config: the DataManagerConfig used to instantiate class
    """

    config: LERFDataManagerConfig

    def __init__(
        self,
        config: LERFDataManagerConfig,
        device: Union[torch.device, str] = "cpu",
        test_mode: Literal["test", "val", "inference"] = "val",
        world_size: int = 1,
        local_rank: int = 0,
        **kwargs,  # pylint: disable=unused-argument
    ):
        super().__init__(
            config=config, device=device, test_mode=test_mode, world_size=world_size, local_rank=local_rank, **kwargs
        )
        self.image_encoder: BaseImageEncoder = kwargs["image_encoder"]
        images = [self.train_dataset[i]["image"].permute(2, 0, 1)[None, ...] for i in range(len(self.train_dataset))]
        images = torch.cat(images)

        cache_dir = f"outputs/{self.config.dataparser.data.name}"
        clip_cache_path = Path(osp.join(cache_dir, f"clip_{self.image_encoder.name}"))
        dino_cache_path = Path(osp.join(cache_dir, "dino.npy"))
        # NOTE: cache config is sensitive to list vs. tuple, because it checks for dict equality
        self.dino_dataloader = DinoDataloader(
            image_list=images,
            device=self.device,
            cfg={"image_shape": list(images.shape[2:4])},
            cache_path=dino_cache_path,
        )
        torch.cuda.empty_cache()
        self.clip_interpolator = PyramidEmbeddingDataloader(
            image_list=images,
            device=self.device,
            cfg={
                "tile_size_range": [0.05, 0.5],
                "tile_size_res": 7,
                "stride_scaler": 0.5,
                "image_shape": list(images.shape[2:4]),
                "model_name": self.image_encoder.name,
            },
            cache_path=clip_cache_path,
            model=self.image_encoder,
        )

    def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
        """Returns the next batch of data from the train dataloader."""
        self.train_count += 1
        image_batch = next(self.iter_train_image_dataloader)
        assert self.train_pixel_sampler is not None
        batch = self.train_pixel_sampler.sample(image_batch)
        ray_indices = batch["indices"]
        ray_bundle = self.train_ray_generator(ray_indices)
        batch["clip"], clip_scale = self.clip_interpolator(ray_indices)
        batch["dino"] = self.dino_dataloader(ray_indices)
        ray_bundle.metadata["clip_scales"] = clip_scale
        # assume all cameras have the same focal length and image width
        ray_bundle.metadata["fx"] = self.train_dataset.cameras[0].fx.item()
        ray_bundle.metadata["width"] = self.train_dataset.cameras[0].width.item()
        ray_bundle.metadata["fy"] = self.train_dataset.cameras[0].fy.item()
        ray_bundle.metadata["height"] = self.train_dataset.cameras[0].height.item()
        return ray_bundle, batch
```

To migrate this custom datamanager to the new datamanager, we shift the data customization from `next_train()` into `custom_ray_processor()`. Only the LERF-specific logic needs to move (the CLIP and DINO feature lookups and the camera-intrinsics metadata); the pixel sampling and ray generation that `next_train()` reimplemented are now handled by the parallel base class.

```python
class LERFDataManager(ParallelDataManager, Generic[TDataset]):

    ...

    def custom_ray_processor(
        self, ray_bundle: RayBundle, batch: Dict
    ) -> Tuple[RayBundle, Dict]:
        """An API to add latents, metadata, or other customization to the parallelized RayBundle dataloading process."""
        ray_indices = batch["indices"]
        batch["clip"], clip_scale = self.clip_interpolator(ray_indices)
        batch["dino"] = self.dino_dataloader(ray_indices)
        ray_bundle.metadata["clip_scales"] = clip_scale
        # assume all cameras have the same focal length and image width
        ray_bundle.metadata["fx"] = self.train_dataset.cameras[0].fx.item()
        ray_bundle.metadata["width"] = self.train_dataset.cameras[0].width.item()
        ray_bundle.metadata["fy"] = self.train_dataset.cameras[0].fy.item()
        ray_bundle.metadata["height"] = self.train_dataset.cameras[0].height.item()
        return ray_bundle, batch
```
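
Once migrated, the datamanager is wired into a method exactly as before. The sketch below is illustrative only: it assumes `LERFDataManagerConfig` and `LERFModelConfig` exist and are constructible with defaults, which may not match the real LERF configs.

```python
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig

# Hypothetical wiring sketch; the LERF config classes and their defaults are assumptions.
config = TrainerConfig(
    method_name="lerf",
    pipeline=VanillaPipelineConfig(
        datamanager=LERFDataManagerConfig(),  # now a ParallelDataManager subclass
        model=LERFModelConfig(),
    ),
)
```

If your method is image-based rather than ray-based, FullImageDatamanager exposes the analogous `custom_view_processor` hook; the migration pattern is the same, but check its definition for the exact signature.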
