Skip to content
30 changes: 12 additions & 18 deletions nerfstudio/data/pixel_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import random
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional, Type, Union

Expand Down Expand Up @@ -335,8 +336,7 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,

# only sample within the mask, if the mask is in the batch
all_indices = []
all_images = []
all_depth_images = []
all_images = defaultdict(list)

assert num_rays_per_batch % 2 == 0, "num_rays_per_batch must be divisible by 2"
num_rays_per_image = divide_rays_per_image(num_rays_per_batch, num_images)
Expand All @@ -350,10 +350,11 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
)
indices[:, 0] = i
all_indices.append(indices)
all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
if "depth_image" in batch:
all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]])

for key, value in batch.items():
if key in ["image_idx", "mask"]:
continue
all_images[key].append(value[i][indices[:, 1], indices[:, 2]])
else:
for i, num_rays in enumerate(num_rays_per_image):
image_height, image_width, _ = batch["image"][i].shape
Expand All @@ -363,26 +364,19 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
indices = self.sample_method(num_rays, 1, image_height, image_width, device=device)
indices[:, 0] = i
all_indices.append(indices)
all_images.append(batch["image"][i][indices[:, 1], indices[:, 2]])
if "depth_image" in batch:
all_depth_images.append(batch["depth_image"][i][indices[:, 1], indices[:, 2]])
for key, value in batch.items():
if key in ["image_idx", "mask"]:
continue
all_images[key].append(value[i][indices[:, 1], indices[:, 2]])

indices = torch.cat(all_indices, dim=0)

c, y, x = (i.flatten() for i in torch.split(indices, 1, dim=-1))
collated_batch = {
key: value[c, y, x]
for key, value in batch.items()
if key not in ("image_idx", "image", "mask", "depth_image") and value is not None
}

collated_batch["image"] = torch.cat(all_images, dim=0)
if "depth_image" in batch:
collated_batch["depth_image"] = torch.cat(all_depth_images, dim=0)
collated_batch = {key: torch.cat(all_images[key], dim=0) for key in all_images}

assert collated_batch["image"].shape[0] == num_rays_per_batch

# Needed to correct the random indices to their actual camera idx locations.
c = indices[..., 0].flatten()
indices[:, 0] = batch["image_idx"][c]
collated_batch["indices"] = indices # with the abs camera indices

Expand Down