@@ -181,8 +181,15 @@ class Config:
181181 # Steps to enable shN adaptive mask
182182 ada_mask_steps : int = 10_000
183183
184- # Enable torch.autograd.detect_anomaly ?
184+ # Enable torch.autograd.detect_anomaly
185185 enable_autograd_detect_anomaly : bool = False
186+
187+ # Enable exporting dynamic splats to per-frame ply files
188+ enable_dyn_splats_export : bool = True
189+ # Time-variant opacity based pruning
190+ temp_opa_vis_pruning : bool = True
191+ # ply file saving mode
192+ sliced_splats_saving_mode : Literal ["splats" , "pcd" ] = "splats"
186193
187194
188195def create_splats_with_optimizers (
@@ -1173,10 +1180,10 @@ def export_dyn_splats_to_ply_sequence(self,):
11731180
11741181 for f_id in range (cfg .duration ):
11751182 timestamp = f_id / cfg .duration
1176- sliced_splats = self .get_sliced_splats_from_dyn_splats (self .splats , timestamp )
1183+ sliced_splats = self .get_sliced_splats_from_dyn_splats (self .splats , timestamp , self . cfg . temp_opa_vis_pruning )
11771184
11781185 ply_filename = ply_dir + f"/{ f_id :03d} .ply"
1179- self .save_static_splats_to_ply (ply_filename , sliced_splats )
1186+ self .save_static_splats_to_ply (ply_filename , sliced_splats , mode = self . cfg . sliced_splats_saving_mode )
11801187
11811188
11821189 @torch .no_grad ()
@@ -1213,7 +1220,7 @@ def trbfunction(x):
12131220 means_motion = means + motion [:, 0 :3 ] * tforpoly + motion [:, 3 :6 ] * tforpoly * tforpoly + motion [:, 6 :9 ] * tforpoly * tforpoly * tforpoly
12141221 # Calculate rotations
12151222 rotations = torch .nn .functional .normalize (quats + tforpoly * omega )
1216-
1223+ import pdb ; pdb . set_trace ()
12171224 if opa_vis_mask :
12181225 vis_mask = trbfoutput .squeeze () > 0.05
12191226 means_motion = means_motion [vis_mask ]
@@ -1297,18 +1304,18 @@ def main(cfg: Config):
12971304 runner .decoder .load_state_dict (ckpts [0 ]["decoder" ])
12981305 step = ckpts [0 ]["step" ]
12991306
1300- # print(f"Evaluate ckpt saved at step {step}")
1301- # runner.eval(step=step)
1307+ print (f"Evaluate ckpt saved at step { step } " )
1308+ runner .eval (step = step )
13021309
13031310 # print(f"Render trajectory using ckpt saved at step {step}")
13041311 # runner.render_traj(step=step)
1312+ if cfg .enable_dyn_splats_export :
1313+ print (f"Save .ply files using ckpt saved at step { step } " )
1314+ runner .export_dyn_splats_to_ply_sequence ()
13051315
1306- print (f"Save .ply files using ckpt saved at step { step } " )
1307- runner .export_dyn_splats_to_ply_sequence ()
1308-
1309- # if cfg.compression is not None:
1310- # print(f"Compress ckpt saved at step {step}")
1311- # runner.run_compression(step=step)
1316+ if cfg .compression is not None :
1317+ print (f"Compress ckpt saved at step { step } " )
1318+ runner .run_compression (step = step )
13121319
13131320 else :
13141321 runner .train ()
0 commit comments