@@ -502,43 +502,285 @@ def forward(ctx, input):
     @staticmethod
     def backward(ctx, grad_output):
         return grad_output
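# [Editor's sketch, not part of the commit] The backward above passes
# gradients through unchanged: the straight-through estimator (STE) pattern.
# A minimal self-contained illustration with a hypothetical RoundSTE; the
# repo's own STE classes (STE, STE_multistep, STE_quant_for_quats) follow the
# same idea but are not shown in this hunk.
import torch

class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # non-differentiable op in the forward pass
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # pretend the forward was the identity: gradient flows through unchanged
        return grad_output

# usage: x = torch.randn(4, requires_grad=True)
#        RoundSTE.apply(x).sum().backward()  # x.grad is all ones, despite round()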
+
+
+class STGCompressionSimulation:
+    """Simulates compression of splat attributes during training: fake
+    quantization plus an optional learned entropy model per attribute."""
+    def __init__(self,
+                 quantization_sim_type: Optional[Literal["round", "noise", "vq"]] = None,
+                 entropy_model_enable: bool = False,
+                 entropy_steps: Optional[Dict[str, int]] = None,
+                 device: Optional[torch.device] = None,
+                 ada_mask_opt: bool = False,
+                 ada_mask_step: int = 10_000,
+                 **kwargs) -> None:
+        self.quantization_sim_type = quantization_sim_type
+
+        self.entropy_model_enable = entropy_model_enable
+        self.entropy_steps = entropy_steps
+        self.device = device
+
+        # simulation_option: specifies which attributes take part in the
+        # compression simulation. Any option set to True must have a
+        # corresponding simulate_fn (see _get_simulate_fn).
+        self.simulation_option = {
+            "means": False,
+            "scales": True,
+            "quats": True,
+            "opacities": True,
+            "trbf_center": False,
+            "trbf_scale": False,
+            "motion": False,  # [N, 9]
+            "omega": False,  # [N, 4]
+            "colors": True,
+            "features_dir": True,
+            "features_time": True,
+        }
+
+        self.shN_qat = False
+        self.shN_ada_mask_opt = ada_mask_opt
+        self.shN_ada_mask_step = ada_mask_step
+
+        # configs for "differentiable quantization"
+        self.q_bitwidth = {
+            "means": None,
+            "scales": 8,
+            "quats": 8,
+            "opacities": 8,
+            "trbf_center": None,
+            "trbf_scale": None,
+            "motion": None,  # [N, 9]
+            "omega": None,  # [N, 4]
+            "colors": 8,
+            "features_dir": 8,
+            "features_time": 8,
+        }
+
+        # per-attribute value bounds used for clamping before fake quantization
+        self.bds = {
+            "means": None,
+            "scales": [-10, 2],
+            "quats": [-1, 1],
+            "opacities": [-7, 7],
+            "trbf_center": None,
+            "trbf_scale": None,
+            "motion": None,  # [N, 9]
+            "omega": None,  # [N, 4]
+            "colors": [-7.5, 7.5],
+            "features_dir": [-10, 10],
+            "features_time": [-10, 10],
+        }
+
+        # configs for "entropy constraint"
+        self.entropy_model_option = {
+            "means": False,
+            "scales": True,
+            "quats": True,
+            "opacities": False,
+            "colors": True,
+            "features_dir": True,
+            "features_time": True,
+            # "shN": False
+        }
+
+        if self.entropy_model_enable:
+            self.entropy_models = {
+                "means": None,
+                "scales": Entropy_factorized_optimized_refactor(channel=3).to(self.device),
+                # "scales": None,
+                "quats": Entropy_factorized_optimized_refactor(channel=4).to(self.device),
+                "opacities": None,
+                "colors": Entropy_factorized_optimized_refactor(channel=3, filters=(3, 3)).to(self.device),
+                "features_dir": Entropy_factorized_optimized_refactor(channel=3, filters=(3, 3)).to(self.device),
+                "features_time": Entropy_factorized_optimized_refactor(channel=3, filters=(3, 3)).to(self.device),
+            }
+
+            # one optimizer per learned entropy model (None for disabled entries)
+            self.entropy_model_optimizers = {}
+            for k, v in self.entropy_models.items():
+                if isinstance(v, (Entropy_factorized, Entropy_factorized_optimized, Entropy_factorized_optimized_refactor)):
+                    v_opt = torch.optim.Adam(
+                        [{"params": p, "lr": 1e-4, "name": n} for n, p in v.named_parameters()]
+                    )
+                    # v_opt = torch.optim.SGD(
+                    #     [{"params": p, "lr": 1e-4, "name": n} for n, p in v.named_parameters()]
+                    # )
+                else:
+                    v_opt = None
+                self.entropy_model_optimizers.update({k: v_opt})
+
+        # configs for "adaptive mask"
+        if self.shN_ada_mask_opt:
+            from .ada_mask import AnnealingMask
+            cap_max = kwargs.get("cap_max", 1_000_000)
+            self.shN_ada_mask = AnnealingMask(input_shape=[cap_max, 1, 1],
+                                              device=device,
+                                              annealing_start_iter=ada_mask_step)
+
+            self.shN_ada_mask_optimizer = torch.optim.Adam([
+                {"params": self.shN_ada_mask.parameters(), "lr": 0.01}
+            ])
+
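# [Editor's sketch, not part of the commit] AnnealingMask comes from the
# repo's .ada_mask module, which this hunk does not show; only its call
# signature (input_shape, device, annealing_start_iter) and the fact that it
# has trainable parameters are visible here. A plausible minimal version,
# purely illustrative: a learnable per-Gaussian logit whose sigmoid is
# sharpened toward a hard 0/1 mask once annealing starts.
import torch

class AnnealingMaskSketch(torch.nn.Module):
    def __init__(self, input_shape, device=None, annealing_start_iter=10_000):
        super().__init__()
        self.mask_logits = torch.nn.Parameter(torch.zeros(input_shape, device=device))
        self.annealing_start_iter = annealing_start_iter

    def forward(self, x: torch.Tensor, step: int) -> torch.Tensor:
        # temperature grows after annealing starts, pushing the soft mask
        # toward binary so masked entries can be pruned at export time
        temperature = 1.0 + max(0, step - self.annealing_start_iter) / 1_000.0
        return x * torch.sigmoid(temperature * self.mask_logits)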
+    def _get_simulate_fn(self, param_name: str) -> Callable:
+        simulate_fn_map = {
+            "means": self.simulate_compression_means,
+            "scales": self.simulate_compression_scales,
+            "quats": self.simulate_compression_quats,
+            "opacities": self.simulate_compression_opacities,
+            # "trbf_center": self.simulate_compression_trbf_center,
+            # "trbf_scale": self.simulate_compression_trbf_scale,
+            # "motion": self.simulate_compression_motion,
+            # "omega": self.simulate_compression_omega,
+            "colors": self.simulate_compression_colors,
+            "features_dir": self.simulate_compression_features_dir,
+            "features_time": self.simulate_compression_features_time,
+        }
+        if param_name in simulate_fn_map:
+            return simulate_fn_map[param_name]
+        else:
+            # Fallback: identity pass-through. Return a (value, esti_bits) pair
+            # so callers can unpack it like the simulate_* methods above
+            # (a bare torch.nn.Identity() would return a single tensor).
+            return lambda param, step, choose_idx: (param, None)
+
+    def simulate_compression(self, splats: Dict[str, Tensor], step: int) -> Tuple[Dict[str, Tensor], Dict[str, Optional[Tensor]]]:
+        """Run the compression simulation on every enabled attribute and return
+        the fake-quantized splats plus per-attribute estimated bits (or None)."""
+        # Create empty dicts for output: fake quantized values and (optional) estimated bits
+        new_splats = {}
+        esti_bits_dict = {}
+
+        # # Randomly sample approximately 5% of the points rather than all points for speedup.
+        # choose_idx = torch.rand_like(splats["means"][:, 0], device=self.device) <= 1
+        choose_idx = None
+
+        for param_name in splats.keys():
+            # Check which params need to be simulated
+            if self.simulation_option[param_name]:
+                simulate_fn = self._get_simulate_fn(param_name)
+                new_splats[param_name], esti_bits_dict[param_name] = simulate_fn(splats[param_name], step, choose_idx)
+            else:
+                # "+ 0." creates a differentiable copy instead of aliasing the parameter
+                new_splats[param_name] = splats[param_name] + 0.
+                esti_bits_dict[param_name] = None
+
+        return new_splats, esti_bits_dict
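# [Editor's sketch, not part of the commit] fake_quantize_ste is used by every
# simulate_* method below but is defined elsewhere in the repo. From the call
# sites, it takes (param, bd_min, bd_max, bitwidth, sim_type) and returns a
# dict with "output_value" and "q_step". A minimal sketch of what such a
# function plausibly does; the real implementation (and its "vq" branch) may
# differ:
import torch

def fake_quantize_ste_sketch(param, bd_min, bd_max, bitwidth, sim_type="round"):
    q_step = (bd_max - bd_min) / (2 ** bitwidth - 1)
    clamped = torch.clamp(param, bd_min, bd_max)
    if sim_type == "noise":
        # additive uniform noise spanning one quantization bin, a common
        # differentiable proxy for rounding
        out = clamped + torch.empty_like(clamped).uniform_(-0.5, 0.5) * q_step
    else:  # "round": hard rounding, identity gradient via the detach trick
        rounded = ((clamped - bd_min) / q_step).round() * q_step + bd_min
        out = (rounded - clamped).detach() + clamped
    return {"output_value": out, "q_step": q_step}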
 
-# to simulate what happens in gsplat's PngCompression()
-def _min_max_quantization_16bit(param: Tensor) -> Tensor:
-    maxs = torch.amax(param, dim=0)
-    mins = torch.amin(param, dim=0)
+    def simulate_compression_means(self, param: torch.nn.Parameter, step: int, choose_idx: Optional[torch.Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+        # Means are passed through unchanged for now; a clamped log-transform
+        # round trip is kept below for reference.
+        # out = torch.clamp(param, -5, 5)
+        # out = inverse_log_transform(log_transform(clamped_param))
+        # return out, None
+        return torch.nn.Identity()(param), None
+
+    def simulate_compression_quats(self, param: torch.nn.Parameter, step: int, choose_idx: Optional[torch.Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+        # fake quantize: warm up at 8 bits, then switch to the configured bitwidth
+        if step < 10_000:
+            fq_out_dict = fake_quantize_ste(param, self.bds["quats"][0], self.bds["quats"][1], 8, self.quantization_sim_type)
+        else:
+            fq_out_dict = fake_quantize_ste(param, self.bds["quats"][0], self.bds["quats"][1], self.q_bitwidth["quats"], self.quantization_sim_type)
+
+        # entropy constraint
+        if step > self.entropy_steps["quats"] and self.entropy_model_enable and self.entropy_model_option["quats"]:
+            if choose_idx is not None:
+                esti_bits = self.entropy_models["quats"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
+            else:
+                esti_bits = self.entropy_models["quats"](fq_out_dict["output_value"], fq_out_dict["q_step"])
 
-    param_norm = (param - mins) / (maxs - mins)
-    q_step = 1 / (2 ** 16 - 1)
-    q_param_norm = (((param_norm / q_step).round() * q_step) - param_norm).detach() + param_norm
+            return fq_out_dict["output_value"], esti_bits
+        else:
+            return fq_out_dict["output_value"], None
 
-    q_param = q_param_norm * (maxs - mins) + mins
+
+    def simulate_compression_scales(self, param: torch.nn.Parameter, step: int, choose_idx: Optional[torch.Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+        # fake quantize: warm up at 8 bits, then switch to the configured bitwidth
+        if step < 10_000:
+            fq_out_dict = fake_quantize_ste(param, self.bds["scales"][0], self.bds["scales"][1], 8, self.quantization_sim_type)
+        else:
+            fq_out_dict = fake_quantize_ste(param, self.bds["scales"][0], self.bds["scales"][1], self.q_bitwidth["scales"], self.quantization_sim_type)
 
-    return q_param
+        # entropy constraint
+        if step > self.entropy_steps["scales"] and self.entropy_model_enable and self.entropy_model_option["scales"]:
+            # factorized model
+            if choose_idx is not None:
+                esti_bits = self.entropy_models["scales"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
+            else:
+                esti_bits = self.entropy_models["scales"](fq_out_dict["output_value"], fq_out_dict["q_step"])
 
-# to simulate what happens in gsplat's PngCompression()
-def _min_max_quantization(param: Tensor) -> Tensor:  # seems not working...
-    maxs = torch.amax(param, dim=0)
-    mins = torch.amin(param, dim=0)
+            # gaussian model (alternative, kept for reference)
+            # mean = torch.mean(fq_out_dict["output_value"][choose_idx])
+            # std = torch.std(fq_out_dict["output_value"][choose_idx])
+            # esti_bits = self.entropy_models["scales"](fq_out_dict["output_value"][choose_idx], mean, std, fq_out_dict["q_step"])
 
-    param_norm = (param - mins) / (maxs - mins)
-    q_step = 1 / (2 ** 8 - 1)
-    q_param_norm = (((param_norm / q_step).round() * q_step) - param_norm).detach() + param_norm
+            return fq_out_dict["output_value"], esti_bits
+        else:
+            return fq_out_dict["output_value"], None
 
-    q_param = q_param_norm * (maxs - mins) + mins
+
+    def simulate_compression_opacities(self, param: torch.nn.Parameter, step: int, choose_idx: Optional[torch.Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+        # fake quantize (opacities always use 8 bits)
+        fq_out_dict = fake_quantize_ste(param, self.bds["opacities"][0], self.bds["opacities"][1], 8, self.quantization_sim_type)
 
-    return q_param
+        # entropy constraint
+        if step > self.entropy_steps["opacities"] and self.entropy_model_enable and self.entropy_model_option["opacities"]:
+            # opacities are [N]; add a channel dim for the entropy model, then drop it
+            fq_out_dict["output_value"] = fq_out_dict["output_value"].unsqueeze(1)
+            if choose_idx is not None:
+                esti_bits = self.entropy_models["opacities"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
+            else:
+                esti_bits = self.entropy_models["opacities"](fq_out_dict["output_value"], fq_out_dict["q_step"])
+            return fq_out_dict["output_value"].squeeze(1), esti_bits
+        else:
+            return fq_out_dict["output_value"], None
+
+    def simulate_compression_colors(self, param: torch.nn.Parameter, step: int, choose_idx: Optional[torch.Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+        # fake quantize: warm up at 8 bits, then switch to the configured bitwidth
+        if step < 10_000:
+            fq_out_dict = fake_quantize_ste(param, self.bds["colors"][0], self.bds["colors"][1], 8, self.quantization_sim_type)
+        else:
+            fq_out_dict = fake_quantize_ste(param, self.bds["colors"][0], self.bds["colors"][1], self.q_bitwidth["colors"], self.quantization_sim_type)
+
+        # entropy constraint
+        if step > self.entropy_steps["colors"] and self.entropy_model_enable and self.entropy_model_option["colors"]:
+            if choose_idx is not None:
+                esti_bits = self.entropy_models["colors"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
+            else:
+                esti_bits = self.entropy_models["colors"](fq_out_dict["output_value"], fq_out_dict["q_step"])
+            return fq_out_dict["output_value"], esti_bits
+        else:
+            return fq_out_dict["output_value"], None
+
 
-def _ste_quantization_for_quats(param: Tensor) -> Tensor:
-    return STE_quant_for_quats.apply(param, 8)
+    def simulate_compression_features_dir(self, param: torch.nn.Parameter, step: int, choose_idx: Optional[torch.Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+        # fake quantize: warm up at 8 bits, then switch to the configured bitwidth
+        if step < 10_000:
+            fq_out_dict = fake_quantize_ste(param, self.bds["features_dir"][0], self.bds["features_dir"][1], 8, self.quantization_sim_type)
+        else:
+            fq_out_dict = fake_quantize_ste(param, self.bds["features_dir"][0], self.bds["features_dir"][1], self.q_bitwidth["features_dir"], self.quantization_sim_type)
 
-def _ste_quantization_given_q_step(param: Tensor) -> Tensor:
-    return STE_multistep.apply(param, 0.001)
+        # entropy constraint
+        if step > self.entropy_steps["features_dir"] and self.entropy_model_enable and self.entropy_model_option["features_dir"]:
+            if choose_idx is not None:
+                esti_bits = self.entropy_models["features_dir"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
+            else:
+                esti_bits = self.entropy_models["features_dir"](fq_out_dict["output_value"], fq_out_dict["q_step"])
+            return fq_out_dict["output_value"], esti_bits
+        else:
+            return fq_out_dict["output_value"], None
+
 
-def _ste_only(param: torch.nn.Parameter) -> torch.nn.Parameter:
-    return param
-    # return STE.apply(param)
-    # return (param.detach() - param.detach()) + param  # not working...
+    def simulate_compression_features_time(self, param: torch.nn.Parameter, step: int, choose_idx: Optional[torch.Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+        # fake quantize: warm up at 8 bits, then switch to the configured bitwidth
+        if step < 10_000:
+            fq_out_dict = fake_quantize_ste(param, self.bds["features_time"][0], self.bds["features_time"][1], 8, self.quantization_sim_type)
+        else:
+            fq_out_dict = fake_quantize_ste(param, self.bds["features_time"][0], self.bds["features_time"][1], self.q_bitwidth["features_time"], self.quantization_sim_type)
 
-def _add_noise_to_simulate_quantization(param: Tensor) -> Tensor:
-    return param + torch.empty_like(param).uniform_(-0.5, 0.5) * 0.001
+        # entropy constraint
+        if step > self.entropy_steps["features_time"] and self.entropy_model_enable and self.entropy_model_option["features_time"]:
+            if choose_idx is not None:
+                esti_bits = self.entropy_models["features_time"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
+            else:
+                esti_bits = self.entropy_models["features_time"](fq_out_dict["output_value"], fq_out_dict["q_step"])
+            return fq_out_dict["output_value"], esti_bits
+        else:
+            return fq_out_dict["output_value"], None
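# [Editor's sketch, not part of the commit] Entropy_factorized_optimized_refactor
# is defined elsewhere in the repo; this hunk only shows that it is built with
# (channel, filters) and called as model(values, q_step) to yield estimated
# bits. A simplified illustration of the underlying fully factorized
# entropy-model idea (Balle et al., 2018), using one learned logistic CDF per
# channel instead of the stacked-filter parameterization:
import torch

class FactorizedBitsSketch(torch.nn.Module):
    def __init__(self, channel: int):
        super().__init__()
        self.loc = torch.nn.Parameter(torch.zeros(channel))
        self.log_scale = torch.nn.Parameter(torch.zeros(channel))

    def forward(self, x: torch.Tensor, q_step: float) -> torch.Tensor:
        # probability mass of one quantization bin under the learned CDF,
        # then bits = -log2(probability); summing gives the rate estimate
        # that the entropy constraint adds to the training loss
        scale = self.log_scale.exp()
        upper = torch.sigmoid((x + q_step / 2 - self.loc) / scale)
        lower = torch.sigmoid((x - q_step / 2 - self.loc) / scale)
        return -torch.log2((upper - lower).clamp_min(1e-9))  # [N, channel]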