Skip to content

Commit 99ffc5e

Browse files
committed
feat: new strategy for STG
1 parent 52fb40a commit 99ffc5e

File tree

5 files changed

+545
-34
lines changed

5 files changed

+545
-34
lines changed

examples/simple_trainer_STG.py

Lines changed: 90 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,15 @@
2424
# from datasets.INVR import Dataset, Parser # This only supports preprocessed Bartender & CBA dataset
2525
from datasets.INVR_N3D import Parser, Dataset # This only supports preprocessed N3D Dataset
2626

27+
from gsplat import strategy
2728
from helper.STG.helper_model import getcolormodel, trbfunction
2829
from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed
2930

3031
from fused_ssim import fused_ssim
3132

3233
from gsplat.compression import PngCompression, STGPngCompression
3334
from gsplat.rendering import rasterization
34-
from gsplat.strategy.STG_Strategy import STG_Strategy # import densification and pruning strategy that fits STG model
35+
from gsplat.strategy import STG_Strategy, Modified_STG_Strategy # import densification and pruning strategy that fits STG model
3536
from gsplat.compression_simulation import STGCompressionSimulation
3637

3738
class ProfilerConfig:
@@ -135,7 +136,14 @@ class Config:
135136
# densification strategy
136137
# strategy: Union[STG_Strategy] = field(
137138
# default_factory=DefaultStrategy
138-
# )
139+
# )
140+
strategy: Literal["STG_Strategy", "Modified_STG_Strategy"] = "STG_Strategy"
141+
142+
# Temporal visibility masking
143+
temp_vis_mask: bool = False
144+
145+
# Test view
146+
test_view_id: List[int] = field(default_factory=lambda: [0]) # Neu3DVideo do not need to specify, but INVR needs
139147

140148
# compression
141149
# Name of compression strategy to use
@@ -351,20 +359,39 @@ def __init__(self, cfg: Config) -> None:
351359

352360
# Densification Strategy
353361
# Only support one type of Densification Strategy for now
354-
self.strategy = STG_Strategy(
355-
verbose=True,
356-
prune_opa=cfg.prune_opa,
357-
grow_grad2d=cfg.grow_grad2d,
358-
grow_scale3d=cfg.grow_scale3d,
359-
# prune_scale3d=cfg.prune_scale3d,
360-
# refine_scale2d_stop_iter=4000, # splatfacto behavior
361-
refine_start_iter=cfg.refine_start_iter,
362-
refine_stop_iter=cfg.refine_stop_iter,
363-
reset_every=cfg.reset_every,
364-
refine_every=cfg.refine_every,
365-
absgrad=cfg.absgrad,
366-
# revised_opacity=cfg.revised_opacity,
367-
)
362+
if cfg.strategy == "STG_Strategy":
363+
self.strategy = STG_Strategy(
364+
verbose=True,
365+
prune_opa=cfg.prune_opa,
366+
grow_grad2d=cfg.grow_grad2d,
367+
grow_scale3d=cfg.grow_scale3d,
368+
# prune_scale3d=cfg.prune_scale3d,
369+
# refine_scale2d_stop_iter=4000, # splatfacto behavior
370+
refine_start_iter=cfg.refine_start_iter,
371+
refine_stop_iter=cfg.refine_stop_iter,
372+
reset_every=cfg.reset_every,
373+
refine_every=cfg.refine_every,
374+
absgrad=cfg.absgrad,
375+
pause_refine_after_reset=cfg.pause_refine_after_reset
376+
# revised_opacity=cfg.revised_opacity,
377+
)
378+
elif cfg.strategy == "Modified_STG_Strategy":
379+
self.strategy = Modified_STG_Strategy(
380+
verbose=True,
381+
prune_opa=cfg.prune_opa,
382+
grow_grad2d=cfg.grow_grad2d,
383+
grow_scale3d=cfg.grow_scale3d,
384+
# prune_scale3d=cfg.prune_scale3d,
385+
# refine_scale2d_stop_iter=4000, # splatfacto behavior
386+
refine_start_iter=cfg.refine_start_iter,
387+
refine_stop_iter=cfg.refine_stop_iter,
388+
reset_every=cfg.reset_every,
389+
refine_every=cfg.refine_every,
390+
absgrad=cfg.absgrad,
391+
pause_refine_after_reset=cfg.pause_refine_after_reset,
392+
temp_vis_mask=cfg.temp_vis_mask
393+
# revised_opacity=cfg.revised_opacity,
394+
)
368395
self.strategy.check_sanity(self.splats, self.optimizers)
369396
self.strategy_state = self.strategy.initialize_state(scene_scale=self.scene_scale)
370397

@@ -448,6 +475,7 @@ def rasterize_splats(
448475
basicfunction,
449476
rays,
450477
camtoworld,
478+
temp_vis_mask = None,
451479
**kwargs,
452480
) -> Tuple[Tensor, Tensor, Dict]:
453481
if not self.cfg.compression_sim:
@@ -458,7 +486,7 @@ def rasterize_splats(
458486
opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
459487

460488
trbfcenter = self.splats["trbf_center"] # [N, 1]
461-
trbfscale = self.splats["trbf_scale"] # [N, 1]
489+
trbfscale = torch.exp(self.splats["trbf_scale"]) # [N, 1]
462490

463491
motion = self.splats["motion"] # [N, 9]
464492
omega = self.splats["omega"] # [N, 4]
@@ -473,21 +501,21 @@ def rasterize_splats(
473501
opacities = torch.sigmoid(self.comp_sim_splats["opacities"]) # [N,]
474502

475503
trbfcenter = self.comp_sim_splats["trbf_center"] # [N, 1]
476-
trbfscale = self.comp_sim_splats["trbf_scale"] # [N, 1]
504+
trbfscale = torch.exp(self.comp_sim_splats["trbf_scale"]) # [N, 1], log domain
477505

478506
motion = self.comp_sim_splats["motion"] # [N, 9]
479507
omega = self.comp_sim_splats["omega"] # [N, 4]
480508
feature_color = self.comp_sim_splats["colors"] # [N, 3]
481509
feature_dir = self.comp_sim_splats["features_dir"] # [N, 3]
482510
feature_time = self.comp_sim_splats["features_time"] # [N, 3]
483511

484-
485512
pointtimes = torch.ones((means.shape[0],1), dtype=means.dtype, requires_grad=False, device="cuda") + 0 #
486513
timestamp = timestamp
487514

488515
trbfdistanceoffset = timestamp * pointtimes - trbfcenter
489-
trbfdistance = trbfdistanceoffset / torch.exp(trbfscale)
490-
trbfoutput = basicfunction(trbfdistance)
516+
trbfdistance = trbfdistanceoffset / (math.sqrt(2) * trbfscale)
517+
trbfoutput = basicfunction(trbfdistance)
518+
491519
# opacity decay
492520
opacity = opacities * trbfoutput.squeeze()
493521

@@ -500,7 +528,18 @@ def rasterize_splats(
500528

501529
# Calculate feature
502530
colors_precomp = torch.cat((feature_color, feature_dir, tforpoly * feature_time), dim=1)
503-
# colors_precomp = feature_color
531+
532+
# Filter out unvisible splats at this timestamp. "mask == 1" means visible, otherwise unvisible.
533+
if temp_vis_mask:
534+
t_vis_mask = trbfoutput.squeeze() > 0.05
535+
means_motion = means_motion[t_vis_mask]
536+
rotations = rotations[t_vis_mask]
537+
scales = scales[t_vis_mask]
538+
opacity = opacity[t_vis_mask]
539+
colors_precomp = colors_precomp[t_vis_mask]
540+
541+
num_t_vis_mask = t_vis_mask.sum()
542+
num_all_splats = t_vis_mask.shape[0]
504543

505544
rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
506545
render_colors, render_alphas, info = rasterization(
@@ -520,6 +559,21 @@ def rasterize_splats(
520559
**kwargs,
521560
)
522561

562+
# (Opt.) modify per-Gaussian related data in "info", except "info['means2d']"
563+
if temp_vis_mask:
564+
for name, v in info.items():
565+
if isinstance(v, torch.Tensor) and name in ["radii", "depths", "conics", "opacities"]:
566+
new_shape = list(v.shape)
567+
new_shape[1] = num_all_splats
568+
new_tensor = torch.zeros(new_shape, dtype=v.dtype, device=v.device)
569+
new_tensor[:, t_vis_mask] = v
570+
571+
info[name] = new_tensor
572+
573+
info.update(
574+
{"t_vis_mask": t_vis_mask}
575+
)
576+
523577
# Decode
524578
render_colors = render_colors.permute(0,3,1,2)
525579
render_colors = self.decoder(render_colors, rays, timestamp) # 1 , 3
@@ -610,16 +664,14 @@ def train(self):
610664

611665
# forward
612666
renders, alphas, info = self.rasterize_splats(
613-
# R=R,
614-
# T=T,
615667
timestamp=timestamp, # [C]
616668
Ks=Ks, # [C, 3, 3]
617669
width=width,
618670
height=height,
619671
basicfunction=trbfunction,
620672
rays=rays, # [C, 6, H, W]
621673
camtoworld=camtoworld, # [C, 4, 4]
622-
# batch_size=cfg.batch_size,
674+
temp_vis_mask=self.cfg.temp_vis_mask
623675
)
624676

625677
if renders.shape[-1] == 4:
@@ -738,6 +790,19 @@ def train(self):
738790
maxbounds=self.maxbounds,
739791
minbounds=self.minbounds,
740792
)
793+
elif isinstance(self.strategy, Modified_STG_Strategy):
794+
flag = self.strategy.step_post_backward(
795+
params=self.splats,
796+
optimizers=self.optimizers,
797+
state=self.strategy_state,
798+
step=step,
799+
info=info,
800+
packed=cfg.packed,
801+
flag=flag,
802+
desicnt=cfg.desicnt,
803+
maxbounds=self.maxbounds,
804+
minbounds=self.minbounds,
805+
)
741806
else:
742807
assert False, "Invalid strategy!"
743808

gsplat/compression_simulation/simulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def __init__(self, quantization_sim_type: Optional[Literal["round", "noise", "vq
310310
"means": None,
311311
"scales": [-10, 2],
312312
"quats": [-1, 1],
313-
"opacities": [-5.5, 7],
313+
"opacities": [-7, 7],
314314
"trbf_center": None,
315315
"trbf_scale": None,
316316
"motion": None, # [N, 9]

gsplat/strategy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from .base import Strategy
22
from .default import DefaultStrategy
33
from .mcmc import MCMCStrategy
4+
from .STG_Strategy import STG_Strategy
5+
from .modified_stg import Modified_STG_Strategy

0 commit comments

Comments
 (0)