@@ -84,7 +84,7 @@ class Config:
8484 source_path : str = ""
8585 model_path : str = ""
8686 images : str = "images"
87- resolution : int = 2 #-1
87+ downscale_factor : int = 2 #-1
8888 white_background : bool = False
8989 veryrify_llff : int = 0
9090 eval : bool = True
@@ -95,7 +95,7 @@ class Config:
9595 # Optimization Params / op
9696 max_steps : int = 30_000
9797 init_opa : float = 0.1 # Initial opacity of GS
98- batch_size : int = 2 # TODO Do not support batch size = 1 for now
98+ batch_size : int = 2
9999 feature_dim : int = 3
100100 device : str = "cuda"
101101 global_scale : float = 1.0 # A global scaler that applies to the scene size related parameters
@@ -106,14 +106,15 @@ class Config:
106106 refine_stop_iter : int = 9_000 # STG changed this param from 15_000 to 9_000, comprared with 3dgs
107107 reset_every : int = 3_000
108108 refine_every : int = 100
109+ pause_refine_after_reset : int = 0
109110 absgrad : bool = False
110111 packed : bool = False
111112 sparse_grad : bool = False
112113 antialiased : bool = False
113114 duration : int = 50 # 20 # number of frames to train
114115 ssim_lambda : float = 0.2 # Weight for SSIM loss
115- save_steps : List [int ] = field (default_factory = lambda : [i for i in range (9_000 , 30_001 , 3_000 )]) # Steps to save the model
116- eval_steps : List [int ] = field (default_factory = lambda : [i for i in range (9_000 , 30_001 , 3_000 )]) # Steps to evaluate the model # 7_000, 30_000
116+ save_steps : List [int ] = field (default_factory = lambda : [i for i in range (9_000 , 75_001 , 3_000 )]) # Steps to save the model
117+ eval_steps : List [int ] = field (default_factory = lambda : [i for i in range (0 , 75_001 , 3_000 )]) # Steps to evaluate the model # 7_000, 30_000
117118 # eval_steps: List[int] = field(default_factory=lambda: [1_000, 2_000, 3_000, 4_000, 5_000, 6_000, 7_000, 25_000, 30_000])
118119 # Number of densification
119120 desicnt : int = 6 # default: 6
@@ -215,15 +216,15 @@ def create_splats_with_optimizers(
215216
216217 # Didn't introduce world_rank and world_size
217218 N = points .shape [0 ]
218- # quats = torch.rand((N, 4)) # [N, 4]
219- quats = torch .zeros ((N , 4 ))
219+ # quats = torch.rand((N, 4))
220+ quats = torch .zeros ((N , 4 )) # [N, 4]
220221 quats [:, 0 ] = 1
221222 # opacities = torch.logit(torch.full((N,), init_opacity)) # [N,]
222223 opacities = inverse_sigmoid (0.1 * torch .ones (N ,))
223- trbf_scale = torch .ones ((N , 1 ))
224+ trbf_scale = torch .log ( torch . ones ((N , 1 ))) # [N, 1]
224225 times = parser .timestamp
225226 times = torch .tensor (times )
226- trbf_center = times .contiguous ()
227+ trbf_center = times .contiguous () # [N, 1]
227228 motion = torch .zeros ((N , 9 ))
228229 omega = torch .zeros ((N , 4 ))
229230
@@ -285,7 +286,7 @@ def __init__(self, cfg: Config) -> None:
285286 # only enable when debug!!
286287 if cfg .enable_autograd_detect_anomaly :
287288 torch .autograd .set_detect_anomaly (True )
288-
289+
289290 self .cfg = cfg
290291 # Write cfg file: Skipped
291292 self .device = self .cfg .device
@@ -310,10 +311,10 @@ def __init__(self, cfg: Config) -> None:
310311
311312 # Load data: Training data should contain initial points and colors.
312313 parser = Parser (model_path = self .cfg .model_path , source_path = self .cfg .data_dir , duration = cfg .duration ,
313- shuffle = False , eval = self .cfg .eval , resolution = cfg .resolution , data_device = 'cpu' )
314+ shuffle = False , eval = self .cfg .eval , downscale_factor = cfg .downscale_factor , data_device = 'cpu' , test_view_id = cfg . test_view_id )
314315 self .parser = parser
315316 self .trainset = Dataset (parser = self .parser , split = "train" , num_views = cfg .batch_size , use_fake_length = True , fake_length = cfg .max_steps + 100 )
316- self .testset = Dataset (parser = self .parser , split = "test" , num_views = 1 )
317+ self .testset = Dataset (parser = self .parser , split = "test" , num_views = len ( cfg . test_view_id ) )
317318
318319 self .trainloader = torch .utils .data .DataLoader (
319320 self .trainset ,
@@ -352,7 +353,7 @@ def __init__(self, cfg: Config) -> None:
352353 self .decoder_optimizer = torch .optim .Adam (self .decoder .parameters (), lr = 0.0001 )
353354
354355 currentxyz = self .splats ["means" ]
355- maxx , maxy , maxz = torch .amax (currentxyz [:,0 ]), torch .amax (currentxyz [:,1 ]), torch .amax (currentxyz [:,2 ])# z wrong...
356+ maxx , maxy , maxz = torch .amax (currentxyz [:,0 ]), torch .amax (currentxyz [:,1 ]), torch .amax (currentxyz [:,2 ])
356357 minx , miny , minz = torch .amin (currentxyz [:,0 ]), torch .amin (currentxyz [:,1 ]), torch .amin (currentxyz [:,2 ])
357358 self .maxbounds = [maxx , maxy , maxz ]
358359 self .minbounds = [minx , miny , minz ]
@@ -878,6 +879,13 @@ def eval(self, step: int, stage: str = "val"):
878879 ellipse_time = 0
879880 metrics = defaultdict (list )
880881 pbar = tqdm .tqdm (range (0 , len (self .testloader )))
882+
883+ # save path
884+ eval_save_path = f"{ self .render_dir } /{ stage } _step{ step } "
885+ os .makedirs (eval_save_path , exist_ok = True )
886+
887+ ## init writer(s) based on num. of test views
888+ writers = [imageio .get_writer (f"{ eval_save_path } /{ stage } _step{ step } _testv{ i } .mp4" , fps = 30 , quality = 10 ) for i in range (len (cfg .test_view_id ))]
881889
882890 for t_idx , batch in enumerate (self .testloader ): # t_idx
883891
@@ -887,10 +895,6 @@ def eval(self, step: int, stage: str = "val"):
887895 timestamp = batch ["timestamp" ][0 ].float ().to (device )
888896 rays = batch ["ray" ][0 ].float ().to (device )
889897 camtoworld = batch ['camtoworld' ][0 ].float ().to (device )
890-
891- # R = camtoworld.cpu().numpy()[0, :3,:3]
892- # T = torch.inverse(camtoworld).cpu().numpy()[0, :3,-1]
893- # new_rays = self.get_rays(R, T, Ks[0,0,0], Ks[0,1,1], width, height).float().to(device)
894898
895899 torch .cuda .synchronize ()
896900 tic = time .time ()
@@ -906,29 +910,39 @@ def eval(self, step: int, stage: str = "val"):
906910 torch .cuda .synchronize ()
907911 ellipse_time += time .time () - tic
908912
909- colors = torch .clamp (colors , 0.0 , 1.0 )
913+ colors = torch .clamp (colors , 0.0 , 1.0 ) # colors: [N, H, W, C]
910914 canvas_list = [pixels , colors ]
911915
912916 desc = ""
913917 if world_rank == 0 :
914- # write GT-vs-rendered image
915918 try :
916- # new version - fpng
917- # canvas = torch.cat(canvas_list, dim=2).squeeze(0).contiguous().cpu().numpy() # fpnge needs [H,W,C]
918- # import pdb; pdb.set_trace()
919- canvas = canvas_list [- 1 ].squeeze (0 ).contiguous ().cpu ().numpy ()
920- canvas = (canvas * 255 ).astype (np .uint8 )
921- import fpnge
922- png = fpnge .fromNP (canvas )
923- with open (f"{ self .render_dir } /{ stage } _step{ step } _{ t_idx :04d} .png" , 'wb' ) as f :
924- f .write (png )
925-
926- canvas = canvas_list [0 ].squeeze (0 ).contiguous ().cpu ().numpy ()
927- canvas = (canvas * 255 ).astype (np .uint8 )
928919 import fpnge
929- png = fpnge .fromNP (canvas )
930- with open (f"{ self .render_dir } /{ stage } _step{ step } _{ t_idx :04d} _gt.png" , 'wb' ) as f :
931- f .write (png )
920+ canvases = torch .cat (canvas_list , dim = 2 ) # canvas: [N, H, 2*W, C]
921+ for i in range (canvases .shape [0 ]): # loop on test views
922+ # save side-by-side comparison
923+ canvas = canvases [i ].contiguous ().cpu ().numpy ()
924+ canvas = (canvas * 255 ).astype (np .uint8 )
925+ # new version - fpng
926+
927+ png = fpnge .fromNP (canvas ) # fpnge needs tensor in order as [H,W,C]
928+ # with open(f"{self.render_dir}/{stage}_step{step}_{t_idx:04d}_testv{i}.png", 'wb') as f:
929+ # f.write(png)
930+ with open (f"{ eval_save_path } /sidebyside_testv{ i } _fid{ t_idx :04d} .png" , 'wb' ) as f :
931+ f .write (png )
932+
933+ # save gt
934+ gt = pixels [i ].contiguous ().cpu ().numpy ()
935+ gt = (gt * 255 ).astype (np .uint8 )
936+ png = fpnge .fromNP (gt )
937+ with open (f"{ eval_save_path } /gt_testv{ i } _fid{ t_idx :04d} .png" , 'wb' ) as f :
938+ f .write (png )
939+
940+ # save rendered
941+ rendered = colors [i ].contiguous ().cpu ().numpy ()
942+ rendered = (rendered * 255 ).astype (np .uint8 )
943+ png = fpnge .fromNP (rendered )
944+ with open (f"{ eval_save_path } /rendered_testv{ i } _fid{ t_idx :04d} .png" , 'wb' ) as f :
945+ f .write (png )
932946
933947 except :
934948 # original version - imageio
@@ -939,6 +953,12 @@ def eval(self, step: int, stage: str = "val"):
939953 canvas ,
940954 )
941955
956+ # save rendered test-view videos
957+ for i in range (colors .shape [0 ]):
958+ color = colors [i ].cpu ().numpy ()
959+ color = (color * 255 ).astype (np .uint8 )
960+ writers [i ].append_data (color )
961+
942962 # write difference image
943963 # difference = abs(colors - pixels).squeeze().detach().cpu().numpy()
944964 # imageio.imwrite(
@@ -956,6 +976,10 @@ def eval(self, step: int, stage: str = "val"):
956976 pbar .update (1 )
957977 pbar .close ()
958978
979+ for i , writer in enumerate (writers ):
980+ writer .close ()
981+ print (f"Video saved to { self .render_dir } /{ stage } _step{ step } _testv{ i } .mp4" )
982+
959983 if world_rank == 0 :
960984 ellipse_time /= len (self .testloader )
961985
@@ -1221,9 +1245,11 @@ def main(cfg: Config):
12211245 runner .splats [k ].data = torch .cat ([ckpt ["splats" ][k ] for ckpt in ckpts ])
12221246 runner .decoder .load_state_dict (ckpts [0 ]["decoder" ])
12231247 step = ckpts [0 ]["step" ]
1224- runner .render_traj (step = step )
1225- # runner.eval(step=step)
1248+ print (f"Evaluate ckpt saved at step { step } " )
1249+ # runner.render_traj(step=step)
1250+ runner .eval (step = step )
12261251 if cfg .compression is not None :
1252+ print (f"Compress ckpt saved at step { step } " )
12271253 runner .run_compression (step = step )
12281254
12291255 else :
0 commit comments