Commit 37f4ca4

add docs, fix depth dataset with parallel datamanager, fix mask sampling bug
1 parent b0fc764 commit 37f4ca4

4 files changed (+8, -5 lines)

docs/developer_guides/pipelines/datamanagers.md (3 additions, 1 deletion)

@@ -109,7 +109,9 @@ ns-train splatfacto --data {PROCESSED_DATA_DIR} --pipeline.datamanager.cache_ima
 ## Migrating Your DataManager to the new DataManager
 Many methods subclass a DataManager and add extra data to it. If you would like your custom datamanager to also support new parallel features, you can migrate any custom dataloading logic to the new `custom_view_processor()` API. Let's take a look at an example for the LERF method, which was built on Nerfstudio's VanillaDataManager. This API provides an interface to attach new information to the RayBundle (for ray based methods), Cameras object (for splatting based methods), or ground truth dictionary. It runs in a background process if disk caching is enabled, otherwise it runs in the main process.
 
-**Note**: naively transfering code to `custom_view_processor` may still OOM on very large datasets if initialization code requires computing something over the whole dataset. To fully take advantage of parallelization make sure your subclassed datamanager computes new information inside the `custom_view_processor`, or caches a subset of the whole dataset. This can also still be slow if pre-computation requires GPU-heavy steps on the same GPU used for training.
+Naively transferring code to `custom_view_processor` may still OOM on very large datasets if initialization code requires computing something over the whole dataset. To take full advantage of parallelization, make sure your subclassed datamanager computes new information inside `custom_view_processor`, or caches a subset of the whole dataset. This can also still be slow if pre-computation requires GPU-heavy steps on the same GPU used for training.
+
+**Note**: Because the parallel DataManager uses background processes, any member of the DataManager needs to be *picklable* to be used inside `custom_view_processor`.
 
 ```python
 class LERFDataManager(VanillaDataManager):
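The migration pattern described in the doc text above can be sketched schematically. The base class, method signature, and batch layout below are simplified stand-ins rather than Nerfstudio's real interfaces; only the `custom_view_processor` override pattern is taken from the documentation.

```python
# Stand-in for a parallel datamanager: the base class calls
# custom_view_processor on each view, possibly in a background process.
class DataManagerStub:
    def custom_view_processor(self, camera, batch):
        # Default: pass the view through unchanged.
        return camera, batch

    def next_train(self, camera, batch):
        return self.custom_view_processor(camera, batch)


class LERFLikeDataManager(DataManagerStub):
    """Attaches a per-view feature lazily instead of precomputing it
    over the whole dataset at init time (the OOM risk noted above)."""

    def custom_view_processor(self, camera, batch):
        # Hypothetical per-view quantity; the real LERF would compute a
        # CLIP embedding of the image here.
        batch = dict(batch, feature=sum(batch["image"]) / len(batch["image"]))
        return camera, batch


dm = LERFLikeDataManager()
camera, out = dm.next_train("cam0", {"image": [0.2, 0.4, 0.6]})
```

Because the per-view computation happens at load time, nothing dataset-sized needs to be held in memory up front.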

nerfstudio/configs/method_configs.py (2 additions, 1 deletion)

@@ -177,6 +177,7 @@
     mixed_precision=True,
     pipeline=VanillaPipelineConfig(
         datamanager=VanillaDataManagerConfig(
+            _target=ParallelDataManager[InputDataset],
             dataparser=NerfstudioDataParserConfig(),
             train_num_rays_per_batch=16384,
             eval_num_rays_per_batch=4096,
@@ -226,7 +227,7 @@
     mixed_precision=True,
     pipeline=VanillaPipelineConfig(
         datamanager=VanillaDataManagerConfig(
-            _target=VanillaDataManager[DepthDataset],
+            _target=ParallelDataManager[DepthDataset],
             dataparser=NerfstudioDataParserConfig(),
             train_num_rays_per_batch=4096,
             eval_num_rays_per_batch=4096,
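The `_target=ParallelDataManager[DepthDataset]` syntax works because the datamanager class is a runtime-subscriptable generic; subscripting yields an alias a config system can store and later introspect. A minimal sketch of that mechanism, with the class names reused for illustration only:

```python
from typing import Generic, TypeVar, get_args, get_origin

TDataset = TypeVar("TDataset")


class InputDataset: ...
class DepthDataset(InputDataset): ...


class ParallelDataManager(Generic[TDataset]):
    """Stub; the real class lives in nerfstudio."""


# Subscripting produces an alias usable as a _target; the dataset
# type can be recovered from it at runtime.
target = ParallelDataManager[DepthDataset]
assert get_origin(target) is ParallelDataManager
assert get_args(target) == (DepthDataset,)
```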

nerfstudio/data/datasets/depth_dataset.py (2 additions, 2 deletions)

@@ -79,7 +79,7 @@ def __init__(
         filenames = dataparser_outputs.image_filenames
 
         repo = "isl-org/ZoeDepth"
-        self.zoe = torch_compile(torch.hub.load(repo, "ZoeD_NK", pretrained=True).to(device))
+        zoe = torch_compile(torch.hub.load(repo, "ZoeD_NK", pretrained=True).to(device))
 
         for i in track(range(len(filenames)), description="Generating depth images"):
             image_filename = filenames[i]
@@ -93,7 +93,7 @@ def __init__(
             image = torch.permute(image, (2, 0, 1)).unsqueeze(0).to(device)
             if image.shape[1] == 4:
                 image = image[:, :3, :, :]
-            depth_tensor = self.zoe.infer(image).squeeze().unsqueeze(-1)
+            depth_tensor = zoe.infer(image).squeeze().unsqueeze(-1)
 
             depth_tensors.append(depth_tensor)
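The change above keeps the ZoeDepth model out of `self`, so only the computed depth tensors survive `__init__` and the dataset stays picklable for the parallel DataManager's background processes (per the docs note in this commit). A toy illustration of why, with a lambda standing in for the compiled torch module:

```python
import pickle


def build_state(store_model: bool) -> dict:
    # A lambda stands in for torch_compile(torch.hub.load(...));
    # lambdas, like many model objects, cannot be pickled.
    model = lambda x: x * 2.0
    state = {"depths": [model(i) for i in range(3)]}  # cached outputs only
    if store_model:
        state["model"] = model  # mimics the old self.zoe = ...
    return state


def is_picklable(obj) -> bool:
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False


# Keeping the model around breaks pickling; keeping only its
# outputs (the depth values) does not.
assert not is_picklable(build_state(store_model=True))
assert is_picklable(build_state(store_model=False))
```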

nerfstudio/data/pixel_samplers.py (1 addition, 1 deletion)

@@ -106,7 +106,7 @@ def rejection_sample_mask(
         num_valid = 0
         for _ in range(self.config.max_num_iterations):
             c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
-            chosen_indices_validity = mask.squeeze()[c, y, x].bool()
+            chosen_indices_validity = mask.squeeze(-1)[c, y, x].bool()
             num_valid = int(torch.sum(chosen_indices_validity).item())
             if num_valid == num_samples:
                 break
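The mask sampling bug fixed above: `squeeze()` with no argument drops every size-1 dimension, so for a dataset with a single image the image axis vanished along with the channel axis, and the `[c, y, x]` indexing no longer matched the tensor's rank. Illustrated here with NumPy shapes for brevity; the repo code uses torch, where `squeeze` behaves the same way:

```python
import numpy as np

# Mask batch for a single image: (num_images, H, W, 1)
mask = np.ones((1, 4, 5, 1), dtype=bool)
c, y, x = np.array([0]), np.array([3]), np.array([2])

# Old code: squeeze() also removes the num_images axis when it is 1,
# leaving (H, W) -- three index arrays no longer fit.
assert mask.squeeze().shape == (4, 5)

# Fixed code: squeeze(-1) removes only the trailing channel axis.
assert mask.squeeze(-1).shape == (1, 4, 5)
valid = mask.squeeze(-1)[c, y, x]
assert valid.shape == (1,)
```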
