@@ -177,9 +177,9 @@ class Config:
177177 enable_autograd_detect_anomaly : bool = False
178178
179179 # Enable exporting dynamic splats to per-frame ply files
180- enable_dyn_splats_export : bool = True
180+ enable_dyn_splats_export : bool = False
181181 # Time-variant opacity based pruning
182- temp_opa_vis_pruning : bool = True
182+ temp_opa_vis_pruning : bool = False
183183 # ply file saving mode
184184 sliced_splats_saving_mode : Literal ["splats" , "pcd" ] = "splats"
185185
@@ -1161,6 +1161,11 @@ def export_dyn_splats_to_ply_sequence(self,):
11611161 shutil .rmtree (ply_dir )
11621162 os .makedirs (ply_dir , exist_ok = True )
11631163
1164+ if self .cfg .temp_opa_vis_pruning :
1165+ print ("[Info] Pruning splats whose opacities are below threshold before saving splats into .ply files." )
1166+ else :
1167+ print ("[Info] Directly saving splats into .ply files, without pruning." )
1168+
11641169 for f_id in range (cfg .duration ):
11651170 timestamp = f_id / cfg .duration
11661171 sliced_splats = self .get_sliced_splats_from_dyn_splats (self .splats , timestamp , self .cfg .temp_opa_vis_pruning )
@@ -1203,7 +1208,7 @@ def trbfunction(x):
12031208 means_motion = means + motion [:, 0 :3 ] * tforpoly + motion [:, 3 :6 ] * tforpoly * tforpoly + motion [:, 6 :9 ] * tforpoly * tforpoly * tforpoly
12041209 # Calculate rotations
12051210 rotations = torch .nn .functional .normalize (quats + tforpoly * omega )
1206- import pdb ; pdb .set_trace ()
1211+ # import pdb; pdb.set_trace()
12071212 if opa_vis_mask :
12081213 vis_mask = trbfoutput .squeeze () > 0.05
12091214 means_motion = means_motion [vis_mask ]
@@ -1215,17 +1220,20 @@ def trbfunction(x):
12151220 num_vis_mask = vis_mask .sum ()
12161221 num_all_splats = vis_mask .shape [0 ]
12171222
1223+ def inverse_sigmoid (x ):
1224+ return torch .log (x / (1 - x ))
1225+
12181226 sliced_splats = {
12191227 "means" : means_motion ,
1220- "scales" : scales ,
1228+ "scales" : torch . log ( scales ) ,
12211229 "quats" : rotations ,
1222- "opacities" : opacity ,
1230+ "opacities" : inverse_sigmoid ( opacity ) ,
12231231 "rgb" : color
12241232 }
12251233
12261234 return sliced_splats
12271235
1228- def save_static_splats_to_ply (self , path , splats , mode = "splats" ):
1236+ def save_static_splats_to_ply (self , path , splats , mode = "splats" , use_text = False ):
12291237 from plyfile import PlyElement , PlyData
12301238
12311239 xyz = splats ["means" ].detach ().cpu ().numpy ()
@@ -1235,29 +1243,31 @@ def save_static_splats_to_ply(self, path, splats, mode="splats"):
12351243 rgb = splats ["rgb" ].detach ().contiguous ().cpu ().numpy ()
12361244
12371245 if mode == "splats" :
1246+ sh0 = rgb_to_sh (rgb )
1247+
12381248 dtype_full = [(attribute , 'f4' ) for attribute in construct_list_of_attributes ()]
12391249 elements = np .empty (xyz .shape [0 ], dtype = dtype_full )
1240- attributes = np .concatenate ((xyz , scales , quats , opacities , rgb ), axis = 1 )
1250+ attributes = np .concatenate ((xyz , scales , quats , opacities , sh0 ), axis = 1 )
12411251 elements [:] = list (map (tuple , attributes ))
1252+
12421253 elif mode == "pcd" :
12431254 rgb = np .round (np .clip (rgb , 0 , 1 ) * 255 ).astype (np .uint8 )
1255+
12441256 dtype_full = [('x' , 'f4' ), ('y' , 'f4' ), ('z' , 'f4' ),
12451257 ('red' , 'u1' ), ('green' , 'u1' ), ('blue' , 'u1' )]
1246-
12471258 elements = np .empty (len (xyz ), dtype = dtype_full )
12481259
12491260 elements ['x' ] = xyz [:, 0 ]
12501261 elements ['y' ] = xyz [:, 1 ]
12511262 elements ['z' ] = xyz [:, 2 ]
1252-
12531263 elements ['red' ] = rgb [:, 0 ]
12541264 elements ['green' ] = rgb [:, 1 ]
12551265 elements ['blue' ] = rgb [:, 2 ]
12561266
12571267 el = PlyElement .describe (elements , 'vertex' )
1258- PlyData ([el ], text = True ).write (path )
1268+ PlyData ([el ], text = use_text ).write (path )
12591269
1260- print (f"Save splats to: { path } " )
1270+ print (f"Save { xyz . shape [ 0 ] } splats to: { path } " )
12611271
12621272def construct_list_of_attributes ():
12631273 l = ['x' , 'y' , 'z' ]
@@ -1267,7 +1277,7 @@ def construct_list_of_attributes():
12671277 l .append ('rot_{}' .format (i ))
12681278 l .append ('opacity' )
12691279 for i in range (3 ):
1270- l .append ('sh0_ {}' .format (i ))
1280+ l .append ('f_dc_ {}' .format (i ))
12711281
12721282 return l
12731283
@@ -1287,8 +1297,9 @@ def main(cfg: Config):
12871297 runner .decoder .load_state_dict (ckpts [0 ]["decoder" ])
12881298 step = ckpts [0 ]["step" ]
12891299
1290- print (f"Evaluate ckpt saved at step { step } " )
1291- runner .eval (step = step )
1300+ if cfg .compression is not None :
1301+ print (f"Evaluate ckpt saved at step { step } " )
1302+ runner .eval (step = step )
12921303
12931304 # print(f"Render trajectory using ckpt saved at step {step}")
12941305 # runner.render_traj(step=step)
0 commit comments