Skip to content

Commit c83ed7d

Browse files
KevinXu02VasuAgrawalethanweberbrentyikerrj
authored
fix the bugs associated with blender datas (#2704)
* fix the bug when camera.distortion_params is None * Handle background color override when using blender. * fix bare except * format * Update background color override in Blender dataparser * Add ability to download EyefulTower dataset * wip before I copy linning's stuff in * Generate per-resolution cameras.xml * Generate transforms.json at download * Fix a couple of quotes * Use official EyefulTower splits for train and val * Disable projectaria-tools on windows * Fix extra imports * Add a new nerfacto method tund for EyefulTower * Split eyefultower download into a separate file * Fix typo * Add some fisheye support for eyeful data * Reformatted imports to not be dumb * Apparently this file was missed when formatting originally * Added 1k resolution scenes * revert method_configs.py to original values * Also add 1k exrs * Add option to modify bg color in gaussian splatting * fix back the config, bg color should work now * removed camera optimizer for gs to align with main * Address feedback * Revert changes to pyproject.toml, to be added in a later PR * Oops, probably shouldn't have gotten rid of awscli ... * adding support for bg color, tested and reformatted now * formatted * formatted * changed glob variable name * Refactor background color variable name * prevent viser overriding --------- Co-authored-by: Vasu Agrawal <[email protected]> Co-authored-by: Ethan Weber <[email protected]> Co-authored-by: Brent Yi <[email protected]> Co-authored-by: Justin Kerr <[email protected]>
1 parent ff4002d commit c83ed7d

File tree

4 files changed

+95
-28
lines changed

4 files changed

+95
-28
lines changed

nerfstudio/data/datamanagers/full_images_datamanager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@ def __init__(
107107
self.train_dataset = self.create_train_dataset()
108108
self.eval_dataset = self.create_eval_dataset()
109109
if len(self.train_dataset) > 500 and self.config.cache_images == "gpu":
110-
CONSOLE.print("Train dataset has over 500 images, overriding cach_images to cpu", style="bold yellow")
110+
CONSOLE.print(
111+
"Train dataset has over 500 images, overriding cache_images to cpu",
112+
style="bold yellow",
113+
)
111114
self.config.cache_images = "cpu"
112115
self.cached_train, self.cached_eval = self.cache_images(self.config.cache_images)
113116
self.exclude_batch_keys_from_device = self.train_dataset.exclude_batch_keys_from_device
@@ -133,6 +136,7 @@ def cache_images(self, cache_images_option):
133136
camera = self.train_dataset.cameras[i].reshape(())
134137
K = camera.get_intrinsics_matrices().numpy()
135138
if camera.distortion_params is None:
139+
cached_train.append(data)
136140
continue
137141
distortion_params = camera.distortion_params.numpy()
138142
image = data["image"].numpy()
@@ -158,6 +162,7 @@ def cache_images(self, cache_images_option):
158162
camera = self.eval_dataset.cameras[i].reshape(())
159163
K = camera.get_intrinsics_matrices().numpy()
160164
if camera.distortion_params is None:
165+
cached_eval.append(data)
161166
continue
162167
distortion_params = camera.distortion_params.numpy()
163168
image = data["image"].numpy()

nerfstudio/data/dataparsers/blender_dataparser.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,12 @@ def __init__(self, config: BlenderDataParserConfig):
5757
self.data: Path = config.data
5858
self.scale_factor: float = config.scale_factor
5959
self.alpha_color = config.alpha_color
60-
61-
def _generate_dataparser_outputs(self, split="train"):
6260
if self.alpha_color is not None:
63-
alpha_color_tensor = get_color(self.alpha_color)
61+
self.alpha_color_tensor = get_color(self.alpha_color)
6462
else:
65-
alpha_color_tensor = None
63+
self.alpha_color_tensor = None
6664

65+
def _generate_dataparser_outputs(self, split="train"):
6766
meta = load_from_json(self.data / f"transforms_{split}.json")
6867
image_filenames = []
6968
poses = []
@@ -98,7 +97,7 @@ def _generate_dataparser_outputs(self, split="train"):
9897
dataparser_outputs = DataparserOutputs(
9998
image_filenames=image_filenames,
10099
cameras=cameras,
101-
alpha_color=alpha_color_tensor,
100+
alpha_color=self.alpha_color_tensor,
102101
scene_box=scene_box,
103102
dataparser_scale=self.scale_factor,
104103
)

nerfstudio/models/splatfacto.py

Lines changed: 84 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from gsplat.sh import num_sh_bases, spherical_harmonics
3232
from pytorch_msssim import SSIM
3333
from torch.nn import Parameter
34+
from typing_extensions import Literal
3435

3536
from nerfstudio.cameras.cameras import Cameras
3637
from nerfstudio.data.scene_box import OrientedBox
@@ -40,6 +41,7 @@
4041
# need following import for background color override
4142
from nerfstudio.model_components import renderers
4243
from nerfstudio.models.base_model import Model, ModelConfig
44+
from nerfstudio.utils.colors import get_color
4345
from nerfstudio.utils.rich_utils import CONSOLE
4446

4547

@@ -109,6 +111,8 @@ class SplatfactoModelConfig(ModelConfig):
109111
"""period of steps where gaussians are culled and densified"""
110112
resolution_schedule: int = 250
111113
"""training starts at 1/d resolution, every n steps this is doubled"""
114+
background_color: Literal["random", "black", "white"] = "random"
115+
"""Whether to randomize the background color."""
112116
num_downscales: int = 0
113117
"""at the beginning, resolution is 1/2^d, where d is this number"""
114118
cull_alpha_thresh: float = 0.1
@@ -135,6 +139,10 @@ class SplatfactoModelConfig(ModelConfig):
135139
"""stop culling/splitting at this step WRT screen size of gaussians"""
136140
random_init: bool = False
137141
"""whether to initialize the positions uniformly randomly (not SFM points)"""
142+
num_random: int = 50000
143+
"""Number of gaussians to initialize if random init is used"""
144+
random_scale: float = 10.0
145+
"Size of the cube to initialize random gaussians within"
138146
ssim_lambda: float = 0.2
139147
"""weight of ssim loss"""
140148
stop_split_at: int = 15000
@@ -171,7 +179,7 @@ def populate_modules(self):
171179
if self.seed_points is not None and not self.config.random_init:
172180
self.means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color)
173181
else:
174-
self.means = torch.nn.Parameter((torch.rand((500000, 3)) - 0.5) * 10)
182+
self.means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale)
175183
self.xys_grad_norm = None
176184
self.max_2Dsize = None
177185
distances, _ = self.k_nearest_sklearn(self.means.data, 3)
@@ -213,7 +221,10 @@ def populate_modules(self):
213221
self.step = 0
214222

215223
self.crop_box: Optional[OrientedBox] = None
216-
self.back_color = torch.zeros(3)
224+
if self.config.background_color == "random":
225+
self.background_color = torch.rand(3)
226+
else:
227+
self.background_color = get_color(self.config.background_color)
217228

218229
@property
219230
def colors(self):
@@ -295,7 +306,10 @@ def dup_in_optim(self, optimizer, dup_mask, new_params, n=2):
295306
param_state = optimizer.state[param]
296307
repeat_dims = (n,) + tuple(1 for _ in range(param_state["exp_avg"].dim() - 1))
297308
param_state["exp_avg"] = torch.cat(
298-
[param_state["exp_avg"], torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat(*repeat_dims)],
309+
[
310+
param_state["exp_avg"],
311+
torch.zeros_like(param_state["exp_avg"][dup_mask.squeeze()]).repeat(*repeat_dims),
312+
],
299313
dim=0,
300314
)
301315
param_state["exp_avg_sq"] = torch.cat(
@@ -339,15 +353,16 @@ def after_train(self, step: int):
339353
self.max_2Dsize = torch.zeros_like(self.radii, dtype=torch.float32)
340354
newradii = self.radii.detach()[visible_mask]
341355
self.max_2Dsize[visible_mask] = torch.maximum(
342-
self.max_2Dsize[visible_mask], newradii / float(max(self.last_size[0], self.last_size[1]))
356+
self.max_2Dsize[visible_mask],
357+
newradii / float(max(self.last_size[0], self.last_size[1])),
343358
)
344359

345360
def set_crop(self, crop_box: Optional[OrientedBox]):
346361
self.crop_box = crop_box
347362

348-
def set_background(self, back_color: torch.Tensor):
349-
assert back_color.shape == (3,)
350-
self.back_color = back_color
363+
def set_background(self, background_color: torch.Tensor):
364+
assert background_color.shape == (3,)
365+
self.background_color = background_color
351366

352367
def refinement_after(self, optimizers: Optimizers, step):
353368
assert step == self.step
@@ -394,17 +409,31 @@ def refinement_after(self, optimizers: Optimizers, step):
394409
) = self.dup_gaussians(dups)
395410
self.means = Parameter(torch.cat([self.means.detach(), split_means, dup_means], dim=0))
396411
self.features_dc = Parameter(
397-
torch.cat([self.features_dc.detach(), split_features_dc, dup_features_dc], dim=0)
412+
torch.cat(
413+
[self.features_dc.detach(), split_features_dc, dup_features_dc],
414+
dim=0,
415+
)
398416
)
399417
self.features_rest = Parameter(
400-
torch.cat([self.features_rest.detach(), split_features_rest, dup_features_rest], dim=0)
418+
torch.cat(
419+
[
420+
self.features_rest.detach(),
421+
split_features_rest,
422+
dup_features_rest,
423+
],
424+
dim=0,
425+
)
401426
)
402427
self.opacities = Parameter(torch.cat([self.opacities.detach(), split_opacities, dup_opacities], dim=0))
403428
self.scales = Parameter(torch.cat([self.scales.detach(), split_scales, dup_scales], dim=0))
404429
self.quats = Parameter(torch.cat([self.quats.detach(), split_quats, dup_quats], dim=0))
405430
# append zeros to the max_2Dsize tensor
406431
self.max_2Dsize = torch.cat(
407-
[self.max_2Dsize, torch.zeros_like(split_scales[:, 0]), torch.zeros_like(dup_scales[:, 0])],
432+
[
433+
self.max_2Dsize,
434+
torch.zeros_like(split_scales[:, 0]),
435+
torch.zeros_like(dup_scales[:, 0]),
436+
],
408437
dim=0,
409438
)
410439

@@ -416,7 +445,14 @@ def refinement_after(self, optimizers: Optimizers, step):
416445

417446
# After a guassian is split into two new gaussians, the original one should also be pruned.
418447
splits_mask = torch.cat(
419-
(splits, torch.zeros(nsamps * splits.sum() + dups.sum(), device=self.device, dtype=torch.bool))
448+
(
449+
splits,
450+
torch.zeros(
451+
nsamps * splits.sum() + dups.sum(),
452+
device=self.device,
453+
dtype=torch.bool,
454+
),
455+
)
420456
)
421457

422458
deleted_mask = self.cull_gaussians(splits_mask)
@@ -433,7 +469,8 @@ def refinement_after(self, optimizers: Optimizers, step):
433469
# Reset value is set to be twice of the cull_alpha_thresh
434470
reset_value = self.config.cull_alpha_thresh * 2.0
435471
self.opacities.data = torch.clamp(
436-
self.opacities.data, max=torch.logit(torch.tensor(reset_value, device=self.device)).item()
472+
self.opacities.data,
473+
max=torch.logit(torch.tensor(reset_value, device=self.device)).item(),
437474
)
438475
# reset the exp of optimizer
439476
optim = optimizers.optimizers["opacity"]
@@ -507,7 +544,14 @@ def split_gaussians(self, split_mask, samps):
507544
self.scales[split_mask] = torch.log(torch.exp(self.scales[split_mask]) / size_fac)
508545
# step 5, sample new quats
509546
new_quats = self.quats[split_mask].repeat(samps, 1)
510-
return new_means, new_features_dc, new_features_rest, new_opacities, new_scales, new_quats
547+
return (
548+
new_means,
549+
new_features_dc,
550+
new_features_rest,
551+
new_opacities,
552+
new_scales,
553+
new_quats,
554+
)
511555

512556
def dup_gaussians(self, dup_mask):
513557
"""
@@ -521,7 +565,14 @@ def dup_gaussians(self, dup_mask):
521565
dup_opacities = self.opacities[dup_mask]
522566
dup_scales = self.scales[dup_mask]
523567
dup_quats = self.quats[dup_mask]
524-
return dup_means, dup_features_dc, dup_features_rest, dup_opacities, dup_scales, dup_quats
568+
return (
569+
dup_means,
570+
dup_features_dc,
571+
dup_features_rest,
572+
dup_opacities,
573+
dup_scales,
574+
dup_quats,
575+
)
525576

526577
@property
527578
def num_points(self):
@@ -573,7 +624,10 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
573624

574625
def _get_downscale_factor(self):
575626
if self.training:
576-
return 2 ** max((self.config.num_downscales - self.step // self.config.resolution_schedule), 0)
627+
return 2 ** max(
628+
(self.config.num_downscales - self.step // self.config.resolution_schedule),
629+
0,
630+
)
577631
else:
578632
return 1
579633

@@ -591,14 +645,23 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
591645
print("Called get_outputs with not a camera")
592646
return {}
593647
assert camera.shape[0] == 1, "Only one camera at a time"
648+
649+
# get the background color
594650
if self.training:
595-
background = torch.rand(3, device=self.device)
651+
if self.config.background_color == "random":
652+
background = torch.rand(3, device=self.device)
653+
elif self.config.background_color == "white":
654+
background = torch.ones(3, device=self.device)
655+
elif self.config.background_color == "black":
656+
background = torch.zeros(3, device=self.device)
657+
else:
658+
background = self.background_color.to(self.device)
596659
else:
597-
# logic for setting the background of the scene
598660
if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
599-
background = renderers.BACKGROUND_COLOR_OVERRIDE
661+
background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
600662
else:
601-
background = self.back_color.to(self.device)
663+
background = self.background_color.to(self.device)
664+
602665
if self.crop_box is not None and not self.training:
603666
crop_ids = self.crop_box.within(self.means).squeeze()
604667
if crop_ids.sum() == 0:
@@ -684,9 +747,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
684747

685748
# rescale the camera back to original dimensions
686749
camera.rescale_output_resolution(camera_downscale)
687-
688750
assert (num_tiles_hit > 0).any() # type: ignore
689-
690751
rgb = rasterize_gaussians( # type: ignore
691752
self.xys,
692753
depths,
@@ -777,7 +838,8 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
777838
scale_exp = torch.exp(self.scales)
778839
scale_reg = (
779840
torch.maximum(
780-
scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1), torch.tensor(self.config.max_gauss_ratio)
841+
scale_exp.amax(dim=-1) / scale_exp.amin(dim=-1),
842+
torch.tensor(self.config.max_gauss_ratio),
781843
)
782844
- self.config.max_gauss_ratio
783845
)

nerfstudio/utils/tensor_dataclass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def _broadcast_dict_fields(self, dict_: Dict, batch_shape) -> Dict:
142142
elif isinstance(v, Dict):
143143
new_dict[k] = self._broadcast_dict_fields(v, batch_shape)
144144
else:
145+
# Don't broadcast the remaining fields
145146
new_dict[k] = v
146147
return new_dict
147148

0 commit comments

Comments
 (0)