2323from typing import Dict , List , Literal , Optional , Tuple , Type , Union
2424
2525import torch
26- from gsplat .strategy import DefaultStrategy
26+ from gsplat .strategy import DefaultStrategy , MCMCStrategy
2727
2828try :
2929 from gsplat .rendering import rasterization
@@ -156,6 +156,16 @@ class SplatfactoModelConfig(ModelConfig):
156156 """Shape of the bilateral grid (X, Y, W)"""
157157 color_corrected_metrics : bool = False
158158 """If True, apply color correction to the rendered images before computing the metrics."""
159+ strategy : Literal ["default" , "mcmc" ] = "default"
160+ """The default strategy will be used if strategy is not specified. Other strategies, e.g. mcmc, can be used."""
161+ max_gs_num : int = 1_000_000
162+ """Maximum number of GSs. Default to 1_000_000."""
163+ noise_lr : float = 5e5
164+ """MCMC sampling noise learning rate. Defaults to 5e5."""
165+ mcmc_opacity_reg : float = 0.01
166+ """Regularization term for opacity in MCMC strategy. Only enabled when using MCMC strategy"""
167+ mcmc_scale_reg : float = 0.01
168+ """Regularization term for scale in MCMC strategy. Only enabled when using MCMC strategy"""
159169
160170
161171class SplatfactoModel (Model ):
@@ -249,24 +259,40 @@ def populate_modules(self):
249259 )
250260
251261 # Strategy for GS densification
252- self .strategy = DefaultStrategy (
253- prune_opa = self .config .cull_alpha_thresh ,
254- grow_grad2d = self .config .densify_grad_thresh ,
255- grow_scale3d = self .config .densify_size_thresh ,
256- grow_scale2d = self .config .split_screen_size ,
257- prune_scale3d = self .config .cull_scale_thresh ,
258- prune_scale2d = self .config .cull_screen_size ,
259- refine_scale2d_stop_iter = self .config .stop_screen_size_at ,
260- refine_start_iter = self .config .warmup_length ,
261- refine_stop_iter = self .config .stop_split_at ,
262- reset_every = self .config .reset_alpha_every * self .config .refine_every ,
263- refine_every = self .config .refine_every ,
264- pause_refine_after_reset = self .num_train_data + self .config .refine_every ,
265- absgrad = self .config .use_absgrad ,
266- revised_opacity = False ,
267- verbose = True ,
268- )
269- self .strategy_state = self .strategy .initialize_state (scene_scale = 1.0 )
262+ if self .config .strategy == "default" :
263+ # Strategy for GS densification
264+ self .strategy = DefaultStrategy (
265+ prune_opa = self .config .cull_alpha_thresh ,
266+ grow_grad2d = self .config .densify_grad_thresh ,
267+ grow_scale3d = self .config .densify_size_thresh ,
268+ grow_scale2d = self .config .split_screen_size ,
269+ prune_scale3d = self .config .cull_scale_thresh ,
270+ prune_scale2d = self .config .cull_screen_size ,
271+ refine_scale2d_stop_iter = self .config .stop_screen_size_at ,
272+ refine_start_iter = self .config .warmup_length ,
273+ refine_stop_iter = self .config .stop_split_at ,
274+ reset_every = self .config .reset_alpha_every * self .config .refine_every ,
275+ refine_every = self .config .refine_every ,
276+ pause_refine_after_reset = self .num_train_data + self .config .refine_every ,
277+ absgrad = self .config .use_absgrad ,
278+ revised_opacity = False ,
279+ verbose = True ,
280+ )
281+ self .strategy_state = self .strategy .initialize_state (scene_scale = 1.0 )
282+ elif self .config .strategy == "mcmc" :
283+ self .strategy = MCMCStrategy (
284+ cap_max = self .config .max_gs_num ,
285+ noise_lr = self .config .noise_lr ,
286+ refine_start_iter = self .config .warmup_length ,
287+ refine_stop_iter = self .config .stop_split_at ,
288+ refine_every = self .config .refine_every ,
289+ min_opacity = self .config .cull_alpha_thresh ,
290+ verbose = False ,
291+ )
292+ self .strategy_state = self .strategy .initialize_state ()
293+ else :
294+ raise ValueError (f"""Splatfacto does not support strategy { self .config .strategy }
295+ Currently, the supported strategies include default and mcmc.""" )
270296
271297 @property
272298 def colors (self ):
@@ -338,14 +364,26 @@ def set_background(self, background_color: torch.Tensor):
338364
339365 def step_post_backward (self , step ):
340366 assert step == self .step
341- self .strategy .step_post_backward (
342- params = self .gauss_params ,
343- optimizers = self .optimizers ,
344- state = self .strategy_state ,
345- step = self .step ,
346- info = self .info ,
347- packed = False ,
348- )
367+ if isinstance (self .strategy , DefaultStrategy ):
368+ self .strategy .step_post_backward (
369+ params = self .gauss_params ,
370+ optimizers = self .optimizers ,
371+ state = self .strategy_state ,
372+ step = self .step ,
373+ info = self .info ,
374+ packed = False ,
375+ )
376+ elif isinstance (self .strategy , MCMCStrategy ):
377+ self .strategy .step_post_backward (
378+ params = self .gauss_params ,
379+ optimizers = self .optimizers ,
380+ state = self .strategy_state ,
381+ step = step ,
382+ info = self .info ,
383+ lr = self .schedulers ["means" ].get_last_lr ()[0 ], # the learning rate for the "means" attribute of the GS
384+ )
385+ else :
386+ raise ValueError (f"Unknown strategy { self .strategy } " )
349387
350388 def get_training_callbacks (
351389 self , training_callback_attributes : TrainingCallbackAttributes
@@ -369,6 +407,7 @@ def get_training_callbacks(
369407 def step_cb (self , optimizers : Optimizers , step ):
370408 self .step = step
371409 self .optimizers = optimizers .optimizers
410+ self .schedulers = optimizers .schedulers
372411
373412 def get_gaussian_param_groups (self ) -> Dict [str , List [Parameter ]]:
374413 # Here we explicitly use the means, scales as parameters so that the user can override this function and
@@ -529,7 +568,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
529568 render_mode = render_mode ,
530569 sh_degree = sh_degree_to_use ,
531570 sparse_grad = False ,
532- absgrad = self .strategy .absgrad ,
571+ absgrad = self .strategy .absgrad if isinstance ( self . strategy , DefaultStrategy ) else False ,
533572 rasterize_mode = self .config .rasterize_mode ,
535574 # set some threshold to disregard small gaussians for faster rendering.
535574 # radius_clip=3.0,
@@ -651,6 +690,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
651690 "scale_reg" : scale_reg ,
652691 }
653692
693+ # Losses for mcmc
694+ if self .config .strategy == "mcmc" :
695+ if self .config .mcmc_opacity_reg > 0.0 :
696+ mcmc_opacity_reg = (
697+ self .config .mcmc_opacity_reg * torch .abs (torch .sigmoid (self .gauss_params ["opacities" ])).mean ()
698+ )
699+ loss_dict ["mcmc_opacity_reg" ] = mcmc_opacity_reg
700+ if self .config .mcmc_scale_reg > 0.0 :
701+ mcmc_scale_reg = self .config .mcmc_scale_reg * torch .abs (torch .exp (self .gauss_params ["scales" ])).mean ()
702+ loss_dict ["mcmc_scale_reg" ] = mcmc_scale_reg
703+
654704 if self .training :
655705 # Add loss from camera optimizer
656706 self .camera_optimizer .get_loss_dict (loss_dict )
0 commit comments