3131from gsplat .sh import num_sh_bases , spherical_harmonics
3232from pytorch_msssim import SSIM
3333from torch .nn import Parameter
34+ from typing_extensions import Literal
3435
3536from nerfstudio .cameras .cameras import Cameras
3637from nerfstudio .data .scene_box import OrientedBox
4041# need following import for background color override
4142from nerfstudio .model_components import renderers
4243from nerfstudio .models .base_model import Model , ModelConfig
44+ from nerfstudio .utils .colors import get_color
4345from nerfstudio .utils .rich_utils import CONSOLE
4446
4547
@@ -109,6 +111,8 @@ class SplatfactoModelConfig(ModelConfig):
109111 """period of steps where gaussians are culled and densified"""
110112 resolution_schedule : int = 250
111113 """training starts at 1/d resolution, every n steps this is doubled"""
114+ background_color : Literal ["random" , "black" , "white" ] = "random"
115+ """Background color to use: "random" randomizes it, while "black" and "white" are fixed colors."""
112116 num_downscales : int = 0
113117 """at the beginning, resolution is 1/2^d, where d is this number"""
114118 cull_alpha_thresh : float = 0.1
@@ -135,6 +139,10 @@ class SplatfactoModelConfig(ModelConfig):
135139 """stop culling/splitting at this step WRT screen size of gaussians"""
136140 random_init : bool = False
137141 """whether to initialize the positions uniformly randomly (not SFM points)"""
142+ num_random : int = 50000
143+ """Number of gaussians to initialize if random init is used"""
144+ random_scale : float = 10.0
145+ "Side length of the origin-centered cube to initialize random gaussians within"
138146 ssim_lambda : float = 0.2
139147 """weight of ssim loss"""
140148 stop_split_at : int = 15000
@@ -171,7 +179,7 @@ def populate_modules(self):
171179 if self .seed_points is not None and not self .config .random_init :
172180 self .means = torch .nn .Parameter (self .seed_points [0 ]) # (Location, Color)
173181 else :
174- self .means = torch .nn .Parameter ((torch .rand ((500000 , 3 )) - 0.5 ) * 10 )
182+ self .means = torch .nn .Parameter ((torch .rand ((self . config . num_random , 3 )) - 0.5 ) * self . config . random_scale )
175183 self .xys_grad_norm = None
176184 self .max_2Dsize = None
177185 distances , _ = self .k_nearest_sklearn (self .means .data , 3 )
@@ -213,7 +221,10 @@ def populate_modules(self):
213221 self .step = 0
214222
215223 self .crop_box : Optional [OrientedBox ] = None
216- self .back_color = torch .zeros (3 )
224+ if self .config .background_color == "random" :
225+ self .background_color = torch .rand (3 )
226+ else :
227+ self .background_color = get_color (self .config .background_color )
217228
218229 @property
219230 def colors (self ):
@@ -295,7 +306,10 @@ def dup_in_optim(self, optimizer, dup_mask, new_params, n=2):
295306 param_state = optimizer .state [param ]
296307 repeat_dims = (n ,) + tuple (1 for _ in range (param_state ["exp_avg" ].dim () - 1 ))
297308 param_state ["exp_avg" ] = torch .cat (
298- [param_state ["exp_avg" ], torch .zeros_like (param_state ["exp_avg" ][dup_mask .squeeze ()]).repeat (* repeat_dims )],
309+ [
310+ param_state ["exp_avg" ],
311+ torch .zeros_like (param_state ["exp_avg" ][dup_mask .squeeze ()]).repeat (* repeat_dims ),
312+ ],
299313 dim = 0 ,
300314 )
301315 param_state ["exp_avg_sq" ] = torch .cat (
@@ -339,15 +353,16 @@ def after_train(self, step: int):
339353 self .max_2Dsize = torch .zeros_like (self .radii , dtype = torch .float32 )
340354 newradii = self .radii .detach ()[visible_mask ]
341355 self .max_2Dsize [visible_mask ] = torch .maximum (
342- self .max_2Dsize [visible_mask ], newradii / float (max (self .last_size [0 ], self .last_size [1 ]))
356+ self .max_2Dsize [visible_mask ],
357+ newradii / float (max (self .last_size [0 ], self .last_size [1 ])),
343358 )
344359
345360 def set_crop (self , crop_box : Optional [OrientedBox ]):
346361 self .crop_box = crop_box
347362
348- def set_background (self , back_color : torch .Tensor ):
349- assert back_color .shape == (3 ,)
350- self .back_color = back_color
363+ def set_background (self , background_color : torch .Tensor ):
364+ assert background_color .shape == (3 ,)
365+ self .background_color = background_color
351366
352367 def refinement_after (self , optimizers : Optimizers , step ):
353368 assert step == self .step
@@ -394,17 +409,31 @@ def refinement_after(self, optimizers: Optimizers, step):
394409 ) = self .dup_gaussians (dups )
395410 self .means = Parameter (torch .cat ([self .means .detach (), split_means , dup_means ], dim = 0 ))
396411 self .features_dc = Parameter (
397- torch .cat ([self .features_dc .detach (), split_features_dc , dup_features_dc ], dim = 0 )
412+ torch .cat (
413+ [self .features_dc .detach (), split_features_dc , dup_features_dc ],
414+ dim = 0 ,
415+ )
398416 )
399417 self .features_rest = Parameter (
400- torch .cat ([self .features_rest .detach (), split_features_rest , dup_features_rest ], dim = 0 )
418+ torch .cat (
419+ [
420+ self .features_rest .detach (),
421+ split_features_rest ,
422+ dup_features_rest ,
423+ ],
424+ dim = 0 ,
425+ )
401426 )
402427 self .opacities = Parameter (torch .cat ([self .opacities .detach (), split_opacities , dup_opacities ], dim = 0 ))
403428 self .scales = Parameter (torch .cat ([self .scales .detach (), split_scales , dup_scales ], dim = 0 ))
404429 self .quats = Parameter (torch .cat ([self .quats .detach (), split_quats , dup_quats ], dim = 0 ))
405430 # append zeros to the max_2Dsize tensor
406431 self .max_2Dsize = torch .cat (
407- [self .max_2Dsize , torch .zeros_like (split_scales [:, 0 ]), torch .zeros_like (dup_scales [:, 0 ])],
432+ [
433+ self .max_2Dsize ,
434+ torch .zeros_like (split_scales [:, 0 ]),
435+ torch .zeros_like (dup_scales [:, 0 ]),
436+ ],
408437 dim = 0 ,
409438 )
410439
@@ -416,7 +445,14 @@ def refinement_after(self, optimizers: Optimizers, step):
416445
417446 # After a gaussian is split into two new gaussians, the original one should also be pruned.
418447 splits_mask = torch .cat (
419- (splits , torch .zeros (nsamps * splits .sum () + dups .sum (), device = self .device , dtype = torch .bool ))
448+ (
449+ splits ,
450+ torch .zeros (
451+ nsamps * splits .sum () + dups .sum (),
452+ device = self .device ,
453+ dtype = torch .bool ,
454+ ),
455+ )
420456 )
421457
422458 deleted_mask = self .cull_gaussians (splits_mask )
@@ -433,7 +469,8 @@ def refinement_after(self, optimizers: Optimizers, step):
433469 # Reset value is set to twice the cull_alpha_thresh
434470 reset_value = self .config .cull_alpha_thresh * 2.0
435471 self .opacities .data = torch .clamp (
436- self .opacities .data , max = torch .logit (torch .tensor (reset_value , device = self .device )).item ()
472+ self .opacities .data ,
473+ max = torch .logit (torch .tensor (reset_value , device = self .device )).item (),
437474 )
438475 # reset the exp of optimizer
439476 optim = optimizers .optimizers ["opacity" ]
@@ -507,7 +544,14 @@ def split_gaussians(self, split_mask, samps):
507544 self .scales [split_mask ] = torch .log (torch .exp (self .scales [split_mask ]) / size_fac )
508545 # step 5, sample new quats
509546 new_quats = self .quats [split_mask ].repeat (samps , 1 )
510- return new_means , new_features_dc , new_features_rest , new_opacities , new_scales , new_quats
547+ return (
548+ new_means ,
549+ new_features_dc ,
550+ new_features_rest ,
551+ new_opacities ,
552+ new_scales ,
553+ new_quats ,
554+ )
511555
512556 def dup_gaussians (self , dup_mask ):
513557 """
@@ -521,7 +565,14 @@ def dup_gaussians(self, dup_mask):
521565 dup_opacities = self .opacities [dup_mask ]
522566 dup_scales = self .scales [dup_mask ]
523567 dup_quats = self .quats [dup_mask ]
524- return dup_means , dup_features_dc , dup_features_rest , dup_opacities , dup_scales , dup_quats
568+ return (
569+ dup_means ,
570+ dup_features_dc ,
571+ dup_features_rest ,
572+ dup_opacities ,
573+ dup_scales ,
574+ dup_quats ,
575+ )
525576
526577 @property
527578 def num_points (self ):
@@ -573,7 +624,10 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
573624
574625 def _get_downscale_factor (self ):
575626 if self .training :
576- return 2 ** max ((self .config .num_downscales - self .step // self .config .resolution_schedule ), 0 )
627+ return 2 ** max (
628+ (self .config .num_downscales - self .step // self .config .resolution_schedule ),
629+ 0 ,
630+ )
577631 else :
578632 return 1
579633
@@ -591,14 +645,23 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
591645 print ("Called get_outputs with not a camera" )
592646 return {}
593647 assert camera .shape [0 ] == 1 , "Only one camera at a time"
648+
649+ # get the background color
594650 if self .training :
595- background = torch .rand (3 , device = self .device )
651+ if self .config .background_color == "random" :
652+ background = torch .rand (3 , device = self .device )
653+ elif self .config .background_color == "white" :
654+ background = torch .ones (3 , device = self .device )
655+ elif self .config .background_color == "black" :
656+ background = torch .zeros (3 , device = self .device )
657+ else :
658+ background = self .background_color .to (self .device )
596659 else :
597- # logic for setting the background of the scene
598660 if renderers .BACKGROUND_COLOR_OVERRIDE is not None :
599- background = renderers .BACKGROUND_COLOR_OVERRIDE
661+ background = renderers .BACKGROUND_COLOR_OVERRIDE . to ( self . device )
600662 else :
601- background = self .back_color .to (self .device )
663+ background = self .background_color .to (self .device )
664+
602665 if self .crop_box is not None and not self .training :
603666 crop_ids = self .crop_box .within (self .means ).squeeze ()
604667 if crop_ids .sum () == 0 :
@@ -684,9 +747,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
684747
685748 # rescale the camera back to original dimensions
686749 camera .rescale_output_resolution (camera_downscale )
687-
688750 assert (num_tiles_hit > 0 ).any () # type: ignore
689-
690751 rgb = rasterize_gaussians ( # type: ignore
691752 self .xys ,
692753 depths ,
@@ -777,7 +838,8 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
777838 scale_exp = torch .exp (self .scales )
778839 scale_reg = (
779840 torch .maximum (
780- scale_exp .amax (dim = - 1 ) / scale_exp .amin (dim = - 1 ), torch .tensor (self .config .max_gauss_ratio )
841+ scale_exp .amax (dim = - 1 ) / scale_exp .amin (dim = - 1 ),
842+ torch .tensor (self .config .max_gauss_ratio ),
781843 )
782844 - self .config .max_gauss_ratio
783845 )
0 commit comments