3030from torchmetrics .image .lpip import LearnedPerceptualImagePatchSimilarity
3131from typing_extensions import Literal , assert_never
3232from gsplat import strategy
33+ from gsplat .compression .entropy_coding_compression import EntropyCodingCompression
3334from gsplat .compression_simulation import simulation
3435from utils import AppearanceOptModule , CameraOptModule , knn , rgb_to_sh , set_random_seed
3536from lib_bilagrid import (
4445from gsplat .rendering import rasterization
4546from gsplat .strategy import DefaultStrategy , MCMCStrategy
4647from gsplat .compression_simulation import CompressionSimulation
48+ from gsplat .compression_simulation .entropy_model import Entropy_factorized_optimized_refactor , Entropy_gaussian
4749
4850class ProfilerConfig :
4951 def __init__ (self ):
@@ -92,7 +94,7 @@ class Config:
9294 # Path to the .pt files. If provide, it will skip training and run evaluation only.
9395 ckpt : Optional [List [str ]] = None
9496 # Name of compression strategy to use
95- compression : Optional [Literal ["png" ]] = None
97+ compression : Optional [Literal ["png" , "entropy_coding" ]] = None
9698
9799 # Enable profiler
98100 profiler_enabled : bool = False
@@ -118,10 +120,10 @@ class Config:
118120 "shN" : 10_000 })
119121 # gaussian model:
120122 # entropy_steps: Dict[str, int] = field(default_factory=lambda: {"means": -1,
121- # "quats": -1 ,
123+ # "quats": 10_000 ,
122124 # "scales": 10_000,
123- # "opacities": -1 ,
124- # "sh0": -1 ,
125+ # "opacities": 10_000 ,
126+ # "sh0": 20_000 ,
125127 # "shN": -1})
126128
127129 # Enable shN adaptive mask
@@ -420,6 +422,8 @@ def __init__(
420422 if cfg .compression is not None :
421423 if cfg .compression == "png" :
422424 self .compression_method = PngCompression ()
425+ elif cfg .compression == "entropy_coding" :
426+ self .compression_method = EntropyCodingCompression ()
423427 else :
424428 raise ValueError (f"Unknown compression strategy: { cfg .compression } " )
425429
@@ -897,6 +901,11 @@ def train(self):
897901
898902 if cfg .shN_ada_mask_opt and step > cfg .ada_mask_steps :
899903 data ["shN_ada_mask" ] = shN_ada_mask
904+
905+ if cfg .compression_sim and cfg .entropy_model_opt and cfg .compression == "entropy_coding" :
906+ for name , entropy_model in self .compression_sim_method .entropy_models .items ():
907+ if entropy_model is not None :
908+ data [name + "_entropy_model" ] = entropy_model .state_dict ()
900909
901910 torch .save (
902911 data , f"{ self .ckpt_dir } /ckpt_{ step } _rank{ self .world_rank } .pt"
@@ -1032,8 +1041,9 @@ def eval(self, step: int, stage: str = "val"):
10321041 canvas_list = [pixels , colors ]
10331042
10341043 if world_rank == 0 :
1035- # write images
1036- canvas = torch .cat (canvas_list , dim = 2 ).squeeze (0 ).cpu ().numpy ()
1044+ # write images
1045+ # canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() # side by side
1046+ canvas = canvas_list [1 ].squeeze (0 ).cpu ().numpy () # signle image
10371047 canvas = (canvas * 255 ).astype (np .uint8 )
10381048 imageio .imwrite (
10391049 f"{ self .render_dir } /{ stage } _step{ step } _{ i :04d} .png" ,
@@ -1074,16 +1084,18 @@ def eval(self, step: int, stage: str = "val"):
10741084 self .writer .flush ()
10751085
10761086 @torch .no_grad ()
1077- def render_traj (self , step : int ):
1087+ def render_traj (self , step : int , stage : str = "val" ):
10781088 """Entry for trajectory rendering."""
10791089 print ("Running trajectory rendering..." )
10801090 cfg = self .cfg
10811091 device = self .device
10821092
1083- camtoworlds_all = self .parser .camtoworlds [5 :- 5 ]
1093+ num_imgs = len (self .parser .camtoworlds )
1094+
1095+ camtoworlds_all = self .parser .camtoworlds [: num_imgs // 2 ]
10841096 if cfg .render_traj_path == "interp" :
10851097 camtoworlds_all = generate_interpolated_path (
1086- camtoworlds_all , 1
1098+ camtoworlds_all , 6 # 1
10871099 ) # [N, 3, 4]
10881100 elif cfg .render_traj_path == "ellipse" :
10891101 height = camtoworlds_all [:, 2 , 3 ].mean ()
@@ -1118,7 +1130,7 @@ def render_traj(self, step: int):
11181130 # save to video
11191131 video_dir = f"{ cfg .result_dir } /videos"
11201132 os .makedirs (video_dir , exist_ok = True )
1121- writer = imageio .get_writer (f"{ video_dir } /traj_ { step } .mp4" , fps = 30 )
1133+ writer = imageio .get_writer (f"{ video_dir } /{ stage } _traj_ { step } .mp4" , fps = 30 )
11221134 for i in tqdm .trange (len (camtoworlds_all ), desc = "Rendering trajectory" ):
11231135 camtoworlds = camtoworlds_all [i : i + 1 ]
11241136 Ks = K [None ]
@@ -1139,11 +1151,12 @@ def render_traj(self, step: int):
11391151 canvas_list = [colors , depths .repeat (1 , 1 , 1 , 3 )]
11401152
11411153 # write images
1142- canvas = torch .cat (canvas_list , dim = 2 ).squeeze (0 ).cpu ().numpy ()
1154+ # canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
1155+ canvas = canvas_list [0 ].squeeze (0 ).cpu ().numpy ()
11431156 canvas = (canvas * 255 ).astype (np .uint8 )
11441157 writer .append_data (canvas )
11451158 writer .close ()
1146- print (f"Video saved to { video_dir } /traj_ { step } .mp4" )
1159+ print (f"Video saved to { video_dir } /{ stage } _traj_ { step } .mp4" )
11471160
11481161 @torch .no_grad ()
11491162 def run_compression (self , step : int ):
@@ -1156,8 +1169,12 @@ def run_compression(self, step: int):
11561169 # import pdb; pdb.set_trace()
11571170 self .run_param_distribution_vis (self .splats , save_dir = f"{ cfg .result_dir } /visualization/raw" )
11581171
1159- self .compression_method .compress (compress_dir , self .splats )
1160- # self.run_param_distribution_vis(self.splats, save_dir=f"{cfg.result_dir}/visualization/log_transform")
1172+ if isinstance (self .compression_method , PngCompression ):
1173+ self .compression_method .compress (compress_dir , self .splats )
1174+ elif isinstance (self .compression_method , EntropyCodingCompression ):
1175+ self .compression_method .compress (compress_dir , self .splats , self .entropy_models )
1176+ else :
1177+ raise NotImplementedError (f"The compression method is not implemented yet." )
11611178
11621179 # evaluate compression
11631180 splats_c = self .compression_method .decompress (compress_dir )
@@ -1167,6 +1184,7 @@ def run_compression(self, step: int):
11671184 for k in splats_c .keys ():
11681185 self .splats [k ].data = splats_c [k ].to (self .device )
11691186 self .eval (step = step , stage = "compress" )
1187+ self .render_traj (step = step , stage = "compress" )
11701188
11711189 @torch .no_grad ()
11721190 def run_param_distribution_vis (self , param_dict : Dict [str , Tensor ], save_dir : str ):
@@ -1199,6 +1217,21 @@ def run_param_distribution_vis(self, param_dict: Dict[str, Tensor], save_dir: st
11991217 plt .close ()
12001218
12011219 print (f"Histograms saved in '{ save_dir } ' directory." )
1220+
1221+ def load_entropy_model_from_ckpt (self , ckpt : Dict , entropy_model_type : str ):
1222+ self .entropy_models = {}
1223+ for name , value in ckpt .items ():
1224+ if "_entropy_model" in name :
1225+ attr_name = name [:(len (name ) - len ("_entropy_model" ))]
1226+ num_ch = ckpt ["splats" ][attr_name ].shape [- 1 ]
1227+ if entropy_model_type == "factorized_model" :
1228+ # TODO
1229+ pass
1230+ elif entropy_model_type == "gaussian_model" :
1231+ entropy_model = Entropy_gaussian (channel = num_ch )
1232+
1233+ entropy_model .load_state_dict (value )
1234+ self .entropy_models [attr_name ] = entropy_model
12021235
12031236 @torch .no_grad ()
12041237 def _viewer_render_fn (
@@ -1242,6 +1275,8 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
12421275 runner .eval (step = step )
12431276 runner .render_traj (step = step )
12441277 if cfg .compression is not None :
1278+ if cfg .compression == "entropy_coding" :
1279+ runner .load_entropy_model_from_ckpt (ckpts [0 ], cfg .entropy_model_type )
12451280 runner .run_compression (step = step )
12461281 else :
12471282 runner .train ()
0 commit comments