2424# from datasets.INVR import Dataset, Parser # This only supports preprocessed Bartender & CBA dataset
2525from datasets .INVR_N3D import Parser , Dataset # This only supports preprocessed N3D Dataset
2626
27+ from gsplat import strategy
2728from helper .STG .helper_model import getcolormodel , trbfunction
2829from utils import AppearanceOptModule , CameraOptModule , knn , rgb_to_sh , set_random_seed
2930
3031from fused_ssim import fused_ssim
3132
3233from gsplat .compression import PngCompression , STGPngCompression
3334from gsplat .rendering import rasterization
34- from gsplat .strategy . STG_Strategy import STG_Strategy # import densification and pruning strategy that fits STG model
35+ from gsplat .strategy import STG_Strategy , Modified_STG_Strategy # import densification and pruning strategy that fits STG model
3536from gsplat .compression_simulation import STGCompressionSimulation
3637
3738class ProfilerConfig :
@@ -135,7 +136,14 @@ class Config:
135136 # densification strategy
136137 # strategy: Union[STG_Strategy] = field(
137138 # default_factory=DefaultStrategy
138- # )
139+ # )
140+ strategy : Literal ["STG_Strategy" , "Modified_STG_Strategy" ] = "STG_Strategy"
141+
142+ # Temporal visibility masking
143+ temp_vis_mask : bool = False
144+
145+ # Test view
146+ test_view_id : List [int ] = field (default_factory = lambda : [0 ]) # Neu3DVideo do not need to specify, but INVR needs
139147
140148 # compression
141149 # Name of compression strategy to use
@@ -351,20 +359,39 @@ def __init__(self, cfg: Config) -> None:
351359
352360 # Densification Strategy
353361 # Only support one type of Densification Strategy for now
354- self .strategy = STG_Strategy (
355- verbose = True ,
356- prune_opa = cfg .prune_opa ,
357- grow_grad2d = cfg .grow_grad2d ,
358- grow_scale3d = cfg .grow_scale3d ,
359- # prune_scale3d=cfg.prune_scale3d,
360- # refine_scale2d_stop_iter=4000, # splatfacto behavior
361- refine_start_iter = cfg .refine_start_iter ,
362- refine_stop_iter = cfg .refine_stop_iter ,
363- reset_every = cfg .reset_every ,
364- refine_every = cfg .refine_every ,
365- absgrad = cfg .absgrad ,
366- # revised_opacity=cfg.revised_opacity,
367- )
362+ if cfg .strategy == "STG_Strategy" :
363+ self .strategy = STG_Strategy (
364+ verbose = True ,
365+ prune_opa = cfg .prune_opa ,
366+ grow_grad2d = cfg .grow_grad2d ,
367+ grow_scale3d = cfg .grow_scale3d ,
368+ # prune_scale3d=cfg.prune_scale3d,
369+ # refine_scale2d_stop_iter=4000, # splatfacto behavior
370+ refine_start_iter = cfg .refine_start_iter ,
371+ refine_stop_iter = cfg .refine_stop_iter ,
372+ reset_every = cfg .reset_every ,
373+ refine_every = cfg .refine_every ,
374+ absgrad = cfg .absgrad ,
375+ pause_refine_after_reset = cfg .pause_refine_after_reset
376+ # revised_opacity=cfg.revised_opacity,
377+ )
378+ elif cfg .strategy == "Modified_STG_Strategy" :
379+ self .strategy = Modified_STG_Strategy (
380+ verbose = True ,
381+ prune_opa = cfg .prune_opa ,
382+ grow_grad2d = cfg .grow_grad2d ,
383+ grow_scale3d = cfg .grow_scale3d ,
384+ # prune_scale3d=cfg.prune_scale3d,
385+ # refine_scale2d_stop_iter=4000, # splatfacto behavior
386+ refine_start_iter = cfg .refine_start_iter ,
387+ refine_stop_iter = cfg .refine_stop_iter ,
388+ reset_every = cfg .reset_every ,
389+ refine_every = cfg .refine_every ,
390+ absgrad = cfg .absgrad ,
391+ pause_refine_after_reset = cfg .pause_refine_after_reset ,
392+ temp_vis_mask = cfg .temp_vis_mask
393+ # revised_opacity=cfg.revised_opacity,
394+ )
368395 self .strategy .check_sanity (self .splats , self .optimizers )
369396 self .strategy_state = self .strategy .initialize_state (scene_scale = self .scene_scale )
370397
@@ -448,6 +475,7 @@ def rasterize_splats(
448475 basicfunction ,
449476 rays ,
450477 camtoworld ,
478+ temp_vis_mask = None ,
451479 ** kwargs ,
452480 ) -> Tuple [Tensor , Tensor , Dict ]:
453481 if not self .cfg .compression_sim :
@@ -458,7 +486,7 @@ def rasterize_splats(
458486 opacities = torch .sigmoid (self .splats ["opacities" ]) # [N,]
459487
460488 trbfcenter = self .splats ["trbf_center" ] # [N, 1]
461- trbfscale = self .splats ["trbf_scale" ] # [N, 1]
489+ trbfscale = torch . exp ( self .splats ["trbf_scale" ]) # [N, 1]
462490
463491 motion = self .splats ["motion" ] # [N, 9]
464492 omega = self .splats ["omega" ] # [N, 4]
@@ -473,21 +501,21 @@ def rasterize_splats(
473501 opacities = torch .sigmoid (self .comp_sim_splats ["opacities" ]) # [N,]
474502
475503 trbfcenter = self .comp_sim_splats ["trbf_center" ] # [N, 1]
476- trbfscale = self .comp_sim_splats ["trbf_scale" ] # [N, 1]
504+ trbfscale = torch . exp ( self .comp_sim_splats ["trbf_scale" ]) # [N, 1], log domain
477505
478506 motion = self .comp_sim_splats ["motion" ] # [N, 9]
479507 omega = self .comp_sim_splats ["omega" ] # [N, 4]
480508 feature_color = self .comp_sim_splats ["colors" ] # [N, 3]
481509 feature_dir = self .comp_sim_splats ["features_dir" ] # [N, 3]
482510 feature_time = self .comp_sim_splats ["features_time" ] # [N, 3]
483511
484-
485512 pointtimes = torch .ones ((means .shape [0 ],1 ), dtype = means .dtype , requires_grad = False , device = "cuda" ) + 0 #
486513 timestamp = timestamp
487514
488515 trbfdistanceoffset = timestamp * pointtimes - trbfcenter
489- trbfdistance = trbfdistanceoffset / torch .exp (trbfscale )
490- trbfoutput = basicfunction (trbfdistance )
516+ trbfdistance = trbfdistanceoffset / (math .sqrt (2 ) * trbfscale )
517+ trbfoutput = basicfunction (trbfdistance )
518+
491519 # opacity decay
492520 opacity = opacities * trbfoutput .squeeze ()
493521
@@ -500,7 +528,18 @@ def rasterize_splats(
500528
501529 # Calculate feature
502530 colors_precomp = torch .cat ((feature_color , feature_dir , tforpoly * feature_time ), dim = 1 )
503- # colors_precomp = feature_color
531+
532+ # Filter out splats that are invisible at this timestamp. "mask == 1" means visible, otherwise invisible.
533+ if temp_vis_mask :
534+ t_vis_mask = trbfoutput .squeeze () > 0.05
535+ means_motion = means_motion [t_vis_mask ]
536+ rotations = rotations [t_vis_mask ]
537+ scales = scales [t_vis_mask ]
538+ opacity = opacity [t_vis_mask ]
539+ colors_precomp = colors_precomp [t_vis_mask ]
540+
541+ num_t_vis_mask = t_vis_mask .sum ()
542+ num_all_splats = t_vis_mask .shape [0 ]
504543
505544 rasterize_mode = "antialiased" if self .cfg .antialiased else "classic"
506545 render_colors , render_alphas , info = rasterization (
@@ -520,6 +559,21 @@ def rasterize_splats(
520559 ** kwargs ,
521560 )
522561
562+ # (Opt.) modify per-Gaussian related data in "info", except "info['means2d']"
563+ if temp_vis_mask :
564+ for name , v in info .items ():
565+ if isinstance (v , torch .Tensor ) and name in ["radii" , "depths" , "conics" , "opacities" ]:
566+ new_shape = list (v .shape )
567+ new_shape [1 ] = num_all_splats
568+ new_tensor = torch .zeros (new_shape , dtype = v .dtype , device = v .device )
569+ new_tensor [:, t_vis_mask ] = v
570+
571+ info [name ] = new_tensor
572+
573+ info .update (
574+ {"t_vis_mask" : t_vis_mask }
575+ )
576+
523577 # Decode
524578 render_colors = render_colors .permute (0 ,3 ,1 ,2 )
525579 render_colors = self .decoder (render_colors , rays , timestamp ) # 1 , 3
@@ -610,16 +664,14 @@ def train(self):
610664
611665 # forward
612666 renders , alphas , info = self .rasterize_splats (
613- # R=R,
614- # T=T,
615667 timestamp = timestamp , # [C]
616668 Ks = Ks , # [C, 3, 3]
617669 width = width ,
618670 height = height ,
619671 basicfunction = trbfunction ,
620672 rays = rays , # [C, 6, H, W]
621673 camtoworld = camtoworld , # [C, 4, 4]
622- # batch_size= cfg.batch_size,
674+ temp_vis_mask = self . cfg .temp_vis_mask
623675 )
624676
625677 if renders .shape [- 1 ] == 4 :
@@ -738,6 +790,19 @@ def train(self):
738790 maxbounds = self .maxbounds ,
739791 minbounds = self .minbounds ,
740792 )
793+ elif isinstance (self .strategy , Modified_STG_Strategy ):
794+ flag = self .strategy .step_post_backward (
795+ params = self .splats ,
796+ optimizers = self .optimizers ,
797+ state = self .strategy_state ,
798+ step = step ,
799+ info = info ,
800+ packed = cfg .packed ,
801+ flag = flag ,
802+ desicnt = cfg .desicnt ,
803+ maxbounds = self .maxbounds ,
804+ minbounds = self .minbounds ,
805+ )
741806 else :
742807 assert False , "Invalid strategy!"
743808
0 commit comments