Skip to content

Commit 072ea3d

Browse files
committed
misc: add cfg support for per-frame gs exportion
1 parent de40d25 commit 072ea3d

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

examples/simple_trainer_dyngs.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,15 @@ class Config:
181181
# Steps to enable shN adaptive mask
182182
ada_mask_steps: int = 10_000
183183

184-
# Enable torch.autograd.detect_anomaly ?
184+
# Enable torch.autograd.detect_anomaly
185185
enable_autograd_detect_anomaly: bool = False
186+
187+
# Enable exporting dynamic splats to per-frame ply files
188+
enable_dyn_splats_export: bool = True
189+
# Time-variant opacity based pruning
190+
temp_opa_vis_pruning: bool = True
191+
# ply file saving mode
192+
sliced_splats_saving_mode: Literal["splats", "pcd"] = "splats"
186193

187194

188195
def create_splats_with_optimizers(
@@ -1173,10 +1180,10 @@ def export_dyn_splats_to_ply_sequence(self,):
11731180

11741181
for f_id in range(cfg.duration):
11751182
timestamp = f_id/cfg.duration
1176-
sliced_splats = self.get_sliced_splats_from_dyn_splats(self.splats, timestamp)
1183+
sliced_splats = self.get_sliced_splats_from_dyn_splats(self.splats, timestamp, self.cfg.temp_opa_vis_pruning)
11771184

11781185
ply_filename = ply_dir + f"/{f_id:03d}.ply"
1179-
self.save_static_splats_to_ply(ply_filename, sliced_splats)
1186+
self.save_static_splats_to_ply(ply_filename, sliced_splats, mode=self.cfg.sliced_splats_saving_mode)
11801187

11811188

11821189
@torch.no_grad()
@@ -1213,7 +1220,7 @@ def trbfunction(x):
12131220
means_motion = means + motion[:, 0:3] * tforpoly + motion[:, 3:6] * tforpoly * tforpoly + motion[:, 6:9] * tforpoly *tforpoly * tforpoly
12141221
# Calculate rotations
12151222
rotations = torch.nn.functional.normalize(quats + tforpoly * omega)
1216-
1223+
import pdb; pdb.set_trace()
12171224
if opa_vis_mask:
12181225
vis_mask = trbfoutput.squeeze() > 0.05
12191226
means_motion = means_motion[vis_mask]
@@ -1297,18 +1304,18 @@ def main(cfg: Config):
12971304
runner.decoder.load_state_dict(ckpts[0]["decoder"])
12981305
step = ckpts[0]["step"]
12991306

1300-
# print(f"Evaluate ckpt saved at step {step}")
1301-
# runner.eval(step=step)
1307+
print(f"Evaluate ckpt saved at step {step}")
1308+
runner.eval(step=step)
13021309

13031310
# print(f"Render trajectory using ckpt saved at step {step}")
13041311
# runner.render_traj(step=step)
1312+
if cfg.enable_dyn_splats_export:
1313+
print(f"Save .ply files using ckpt saved at step {step}")
1314+
runner.export_dyn_splats_to_ply_sequence()
13051315

1306-
print(f"Save .ply files using ckpt saved at step {step}")
1307-
runner.export_dyn_splats_to_ply_sequence()
1308-
1309-
# if cfg.compression is not None:
1310-
# print(f"Compress ckpt saved at step {step}")
1311-
# runner.run_compression(step=step)
1316+
if cfg.compression is not None:
1317+
print(f"Compress ckpt saved at step {step}")
1318+
runner.run_compression(step=step)
13121319

13131320
else:
13141321
runner.train()

0 commit comments

Comments
 (0)