Skip to content

Commit 0cc5cce

Browse files
committed
bugfix: bugs coming from branch merging
1 parent 504d4cb commit 0cc5cce

File tree

5 files changed

+297
-40
lines changed

5 files changed

+297
-40
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ cd examples
2525
pip install -r requirements.txt
2626
# download mipnerf_360 benchmark data
2727
python datasets/download_dataset.py
28-
# place other dataset under 'data' folder
28+
# or place other dataset under 'data' folder
29+
ln -s data/tandt /xxxx/Dataset/tandt
2930
```
3031

3132
We also use third-party library, 'python-fpnge', to accelerate image saving operations during the experiment for now.

gsplat/compression/entropy_coding_compression.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@
1313
from gsplat.compression.sort import sort_splats
1414
from gsplat.utils import inverse_log_transform, log_transform
1515

16-
try:
17-
import constriction
18-
except:
19-
raise ImportError(
20-
"Please install constriction with 'pip install constriction' to use ANS"
21-
)
16+
2217

2318

2419
@dataclass
@@ -262,6 +257,13 @@ def _get_likelihood(symbols: np.array, bitwidth: int=8) -> np.array:
262257
pass
263258

264259
def _categorical_ans_encode(symbols: np.array, probabilities: np.array, save_path:str):
260+
try:
261+
import constriction
262+
except:
263+
raise ImportError(
264+
"Please install constriction with 'pip install constriction' to use ANS"
265+
)
266+
265267
num_symbols = symbols.shape[-1]
266268

267269
message_list = []
@@ -364,7 +366,13 @@ def _decompress_factorized_ans(compress_dir: str, param_name: str, meta: Dict[st
364366
Returns:
365367
Tensor: parameters
366368
"""
367-
import imageio.v2 as imageio
369+
try:
370+
import constriction
371+
except:
372+
raise ImportError(
373+
"Please install constriction with 'pip install constriction' to use ANS"
374+
)
375+
368376
if param_name == "sh0":
369377
import pdb; pdb.set_trace()
370378
if not np.all(meta["shape"]):
@@ -413,6 +421,12 @@ def _compress_gaussian_ans(
413421
Returns:
414422
Dict[str, Any]: metadata
415423
"""
424+
try:
425+
import constriction
426+
except:
427+
raise ImportError(
428+
"Please install constriction with 'pip install constriction' to use ANS"
429+
)
416430

417431
mins = torch.amin(params, dim=0)
418432
maxs = torch.amax(params, dim=0)
@@ -683,7 +697,7 @@ def _compress_kmeans(
683697
"dtype": str(params.dtype).split(".")[1],
684698
}
685699
return meta
686-
# import pdb; pdb.set_trace()
700+
687701
kmeans = KMeans(n_clusters=n_clusters, distance="manhattan", verbose=verbose)
688702
x = params.reshape(params.shape[0], -1).permute(1, 0).contiguous()
689703
labels = kmeans.fit(x)

gsplat/compression_simulation/gaussian_distribution_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import _gridencoder as _backend
99
except:
1010
raise ImportError(
11-
"Please install gridencoder with 'pip install gscodec/gridencoder' to use hash encoding"
11+
"Please install gridencoder with 'pip install third_party/gridencoder' to use hash encoding"
1212
)
1313

1414
anchor_round_digits = 16

gsplat/compression_simulation/simulation.py

Lines changed: 270 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -502,43 +502,285 @@ def forward(ctx, input):
502502
@staticmethod
503503
def backward(ctx, grad_output):
504504
return grad_output
505+
506+
507+
508+
class STGCompressionSimulation:
509+
"""
510+
"""
511+
def __init__(self, quantization_sim_type: Optional[Literal["round", "noise", "vq"]] = None,
512+
entropy_model_enable: bool = False,
513+
entropy_steps: Dict[str, int] = None,
514+
device: device = None,
515+
ada_mask_opt: bool = False,
516+
ada_mask_step: int = 10_000,
517+
**kwargs) -> None:
518+
self.quantization_sim_type = quantization_sim_type
519+
520+
self.entropy_model_enable = entropy_model_enable
521+
self.entropy_steps = entropy_steps
522+
self.device = device
523+
524+
# simulation_option: dict to specify which properties should be involved in the compression simulation.
525+
# Once option is set to True, it must have corresponding simulate_fn
526+
self.simulation_option = {
527+
"means": False,
528+
"scales": True,
529+
"quats": True,
530+
"opacities": True,
531+
"trbf_center": False,
532+
"trbf_scale": False,
533+
"motion": False, # [N, 9]
534+
"omega": False, # [N, 4]
535+
"colors": True,
536+
"features_dir": True,
537+
"features_time": True,
538+
}
539+
540+
self.shN_qat = False
541+
self.shN_ada_mask_opt = ada_mask_opt
542+
self.shN_ada_mask_step = ada_mask_step
543+
544+
# configs for "differentiable quantization"
545+
self.q_bitwidth = {
546+
"means": None,
547+
"scales": 8,
548+
"quats": 8,
549+
"opacities": 8,
550+
"trbf_center": None,
551+
"trbf_scale": None,
552+
"motion": None, # [N, 9]
553+
"omega": None, # [N, 4]
554+
"colors": 8,
555+
"features_dir": 8,
556+
"features_time": 8,
557+
}
558+
559+
self.bds = {
560+
"means": None,
561+
"scales": [-10, 2],
562+
"quats": [-1, 1],
563+
"opacities": [-7, 7],
564+
"trbf_center": None,
565+
"trbf_scale": None,
566+
"motion": None, # [N, 9]
567+
"omega": None, # [N, 4]
568+
"colors": [-7.5, 7.5],
569+
"features_dir": [-10, 10],
570+
"features_time": [-10, 10],
571+
}
572+
573+
# configs for "entropy constraint"
574+
self.entropy_model_option = {
575+
"means": False,
576+
"scales": True,
577+
"quats": True,
578+
"opacities": False,
579+
"colors": True,
580+
"features_dir": True,
581+
"features_time": True
582+
# "shN": False
583+
}
584+
585+
if self.entropy_model_enable:
586+
self.entropy_models = {
587+
"means": None,
588+
"scales": Entropy_factorized_optimized_refactor(channel=3).to(self.device),
589+
# "scales": None,
590+
"quats": Entropy_factorized_optimized_refactor(channel=4).to(self.device),
591+
"opacities": None,
592+
"colors": Entropy_factorized_optimized_refactor(channel=3, filters=(3, 3)).to(self.device),
593+
"features_dir": Entropy_factorized_optimized_refactor(channel=3, filters=(3, 3)).to(self.device),
594+
"features_time": Entropy_factorized_optimized_refactor(channel=3, filters=(3, 3)).to(self.device),
595+
}
596+
597+
self.entropy_model_optimizers = {}
598+
for k, v in self.entropy_models.items():
599+
if isinstance(v, Entropy_factorized) or isinstance(v, Entropy_factorized_optimized) or isinstance(v, Entropy_factorized_optimized_refactor):
600+
v_opt = torch.optim.Adam(
601+
[{"params": p, "lr": 1e-4, "name": n} for n, p in v.named_parameters()]
602+
)
603+
# v_opt = torch.optim.SGD(
604+
# [{"params": p, "lr": 1e-4, "name": n} for n, p in v.named_parameters()]
605+
# )
606+
else:
607+
v_opt = None
608+
self.entropy_model_optimizers.update({k: v_opt})
609+
610+
# configs for "adaptive mask"
611+
if self.shN_ada_mask_opt:
612+
from .ada_mask import AnnealingMask
613+
cap_max = kwargs.get("cap_max", 1_000_000)
614+
self.shN_ada_mask = AnnealingMask(input_shape=[cap_max, 1, 1],
615+
device=device,
616+
annealing_start_iter=ada_mask_step)
617+
618+
self.shN_ada_mask_optimizer = torch.optim.Adam([
619+
{'params': self.shN_ada_mask.parameters(), 'lr': 0.01}
620+
])
621+
622+
def _get_simulate_fn(self, param_name: str) -> Callable:
623+
simulate_fn_map = {
624+
"means": self.simulate_compression_means,
625+
"scales": self.simulate_compression_scales,
626+
"quats": self.simulate_compression_quats,
627+
"opacities": self.simulate_compression_opacities,
628+
# "trbf_center": self.simulate_compression_trbf_center,
629+
# "trbf_scale": self.simulate_compression_trbf_scale,
630+
# "motion": self.simulate_compression_motion,
631+
# "omega": self.simulate_compression_omega,
632+
"colors": self.simulate_compression_colors,
633+
"features_dir": self.simulate_compression_features_dir,
634+
"features_time": self.simulate_compression_features_time
635+
}
636+
if param_name in simulate_fn_map:
637+
return simulate_fn_map[param_name]
638+
else:
639+
return torch.nn.Identity()
640+
641+
def simulate_compression(self, splats: Dict[str, Tensor], step: int) -> Dict[str, Tensor]:
642+
"""
643+
"""
644+
# Create empty dicts for output, including fake quantized values and (optional) estimated bits
645+
new_splats = {}
646+
esti_bits_dict = {}
647+
648+
# # Randomly sample approximately 5% of the points rather than all points for speedup.
649+
# choose_idx = torch.rand_like(splats["means"][:, 0], device=self.device) <= 1
650+
choose_idx = None
651+
652+
for param_name in splats.keys():
653+
# Check which params need to be simulate
654+
if self.simulation_option[param_name]:
655+
simulate_fn = self._get_simulate_fn(param_name)
656+
new_splats[param_name], esti_bits_dict[param_name] = simulate_fn(splats[param_name], step, choose_idx)
657+
else:
658+
new_splats[param_name] = splats[param_name] + 0.
659+
esti_bits_dict[param_name] = None
660+
661+
return new_splats, esti_bits_dict
505662

506-
# to simulate what happens in gsplat 's PngCompression()
507-
def _min_max_quantization_16bit(param: Tensor) -> Tensor:
508-
maxs = torch.amax(param, dim=0)
509-
mins = torch.amin(param, dim=0)
663+
def simulate_compression_means(self, param: torch.nn.Parameter, step: int, choose_idx: torch.Tensor) -> Tensor:
664+
# out = torch.clamp(param, -5, 5)
665+
# out = inverse_log_transform(log_transform(clamped_param))
666+
667+
# return out, None
668+
return torch.nn.Identity()(param), None
669+
670+
def simulate_compression_quats(self, param: torch.nn.Parameter, step: int, choose_idx: torch.Tensor) -> Tensor:
671+
# fake quantize
672+
if step < 10_000:
673+
fq_out_dict = fake_quantize_ste(param, self.bds["quats"][0], self.bds["quats"][1], 8, self.quantization_sim_type)
674+
else:
675+
fq_out_dict = fake_quantize_ste(param, self.bds["quats"][0], self.bds["quats"][1], self.q_bitwidth["quats"], self.quantization_sim_type)
676+
677+
# entropy constraint
678+
if step > self.entropy_steps["quats"] and self.entropy_model_enable and self.entropy_model_option["quats"]:
679+
# import pdb; pdb.set_trace()
680+
if choose_idx is not None:
681+
esti_bits = self.entropy_models["quats"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
682+
else:
683+
esti_bits = self.entropy_models["quats"](fq_out_dict["output_value"], fq_out_dict["q_step"])
510684

511-
param_norm = (param - mins) / (maxs - mins)
512-
q_step = 1 / (2**16 - 1)
513-
q_param_norm = (((param_norm / q_step).round() * q_step) - param_norm).detach() + param_norm
685+
return fq_out_dict["output_value"], esti_bits
686+
else:
687+
return fq_out_dict["output_value"], None
514688

515-
q_param = q_param_norm * (maxs - mins) + mins
689+
690+
def simulate_compression_scales(self, param: torch.nn.Parameter, step: int, choose_idx: torch.Tensor) -> Tensor:
691+
# fake quantize
692+
if step < 10_000:
693+
fq_out_dict = fake_quantize_ste(param, self.bds["scales"][0], self.bds["scales"][1], 8, self.quantization_sim_type)
694+
else:
695+
fq_out_dict = fake_quantize_ste(param, self.bds["scales"][0], self.bds["scales"][1], self.q_bitwidth["scales"], self.quantization_sim_type)
516696

517-
return q_param
697+
# entropy constraint
698+
if step > self.entropy_steps["scales"] and self.entropy_model_enable and self.entropy_model_option["scales"]:
699+
# import pdb; pdb.set_trace()
700+
# factorized model
701+
if choose_idx is not None:
702+
esti_bits = self.entropy_models["scales"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
703+
else:
704+
esti_bits = self.entropy_models["scales"](fq_out_dict["output_value"], fq_out_dict["q_step"])
518705

519-
# to simulate what happens in gsplat 's PngCompression()
520-
def _min_max_quantization(param: Tensor) -> Tensor: # seems not working...
521-
maxs = torch.amax(param, dim=0)
522-
mins = torch.amin(param, dim=0)
706+
# gaussian model
707+
# mean = torch.mean(fq_out_dict["output_value"][choose_idx])
708+
# std = torch.std(fq_out_dict["output_value"][choose_idx])
709+
# esti_bits = self.entropy_models["scales"](fq_out_dict["output_value"][choose_idx], mean, std, fq_out_dict["q_step"])
523710

524-
param_norm = (param - mins) / (maxs - mins)
525-
q_step = 1 / (2**8 - 1)
526-
q_param_norm = (((param_norm / q_step).round() * q_step) - param_norm).detach() + param_norm
711+
return fq_out_dict["output_value"], esti_bits
712+
else:
713+
return fq_out_dict["output_value"], None
527714

528-
q_param = q_param_norm * (maxs - mins) + mins
715+
716+
def simulate_compression_opacities(self, param: torch.nn.Parameter, step: int, choose_idx: torch.Tensor) -> Tensor:
717+
# fake quantize
718+
fq_out_dict = fake_quantize_ste(param, self.bds["opacities"][0], self.bds["opacities"][1], 8, self.quantization_sim_type)
529719

530-
return q_param
720+
# entropy constraint
721+
if step > self.entropy_steps["opacities"] and self.entropy_model_enable and self.entropy_model_option["opacities"]:
722+
fq_out_dict["output_value"] = fq_out_dict["output_value"].unsqueeze(1)
723+
if choose_idx is not None:
724+
esti_bits = self.entropy_models["opacities"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
725+
else:
726+
esti_bits = self.entropy_models["opacities"](fq_out_dict["output_value"], fq_out_dict["q_step"])
727+
return fq_out_dict["output_value"].squeeze(1), esti_bits
728+
else:
729+
return fq_out_dict["output_value"], None
730+
731+
732+
def simulate_compression_colors(self, param: torch.nn.Parameter, step: int, choose_idx: torch.Tensor) -> Tensor:
733+
# fake quantize
734+
if step < 10_000:
735+
fq_out_dict = fake_quantize_ste(param, self.bds["colors"][0], self.bds["colors"][1], 8, self.quantization_sim_type)
736+
else:
737+
fq_out_dict = fake_quantize_ste(param, self.bds["colors"][0], self.bds["colors"][1], self.q_bitwidth["colors"], self.quantization_sim_type)
738+
739+
# entropy constraint
740+
if step > self.entropy_steps["colors"] and self.entropy_model_enable and self.entropy_model_option["colors"]:
741+
fq_out_dict["output_value"] = fq_out_dict["output_value"]
742+
if choose_idx is not None:
743+
esti_bits = self.entropy_models["colors"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
744+
else:
745+
esti_bits = self.entropy_models["colors"](fq_out_dict["output_value"], fq_out_dict["q_step"])
746+
return fq_out_dict["output_value"], esti_bits
747+
else:
748+
return fq_out_dict["output_value"], None
749+
531750

532-
def _ste_quantization_for_quats(param: Tensor) -> Tensor:
533-
return STE_quant_for_quats.apply(param, 8)
751+
def simulate_compression_features_dir(self, param: torch.nn.Parameter, step: int, choose_idx: torch.Tensor) -> Tensor:
752+
# fake quantize
753+
if step < 10_000:
754+
fq_out_dict = fake_quantize_ste(param, self.bds["features_dir"][0], self.bds["features_dir"][1], 8, self.quantization_sim_type)
755+
else:
756+
fq_out_dict = fake_quantize_ste(param, self.bds["features_dir"][0], self.bds["features_dir"][1], self.q_bitwidth["features_dir"], self.quantization_sim_type)
534757

535-
def _ste_quantization_given_q_step(param: Tensor) -> Tensor:
536-
return STE_multistep.apply(param, 0.001)
758+
# entropy constraint
759+
if step > self.entropy_steps["features_dir"] and self.entropy_model_enable and self.entropy_model_option["features_dir"]:
760+
fq_out_dict["output_value"] = fq_out_dict["output_value"]
761+
if choose_idx is not None:
762+
esti_bits = self.entropy_models["features_dir"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
763+
else:
764+
esti_bits = self.entropy_models["features_dir"](fq_out_dict["output_value"], fq_out_dict["q_step"])
765+
return fq_out_dict["output_value"], esti_bits
766+
else:
767+
return fq_out_dict["output_value"], None
768+
537769

538-
def _ste_only(param: torch.nn.Parameter) -> torch.nn.Parameter:
539-
return param
540-
# return STE.apply(param)
541-
# return (param.detach() - param.detach()) + param # not working...
770+
def simulate_compression_features_time(self, param: torch.nn.Parameter, step: int, choose_idx: torch.Tensor) -> Tensor:
771+
# fake quantize
772+
if step < 10_000:
773+
fq_out_dict = fake_quantize_ste(param, self.bds["features_time"][0], self.bds["features_time"][1], 8, self.quantization_sim_type)
774+
else:
775+
fq_out_dict = fake_quantize_ste(param, self.bds["features_time"][0], self.bds["features_time"][1], self.q_bitwidth["features_time"], self.quantization_sim_type)
542776

543-
def _add_noise_to_simulate_quantization(param: Tensor) -> Tensor:
544-
return param + torch.empty_like(param).uniform_(-0.5, 0.5) * 0.001
777+
# entropy constraint
778+
if step > self.entropy_steps["features_time"] and self.entropy_model_enable and self.entropy_model_option["features_time"]:
779+
fq_out_dict["output_value"] = fq_out_dict["output_value"]
780+
if choose_idx is not None:
781+
esti_bits = self.entropy_models["features_time"](fq_out_dict["output_value"][choose_idx], fq_out_dict["q_step"])
782+
else:
783+
esti_bits = self.entropy_models["features_time"](fq_out_dict["output_value"], fq_out_dict["q_step"])
784+
return fq_out_dict["output_value"], esti_bits
785+
else:
786+
return fq_out_dict["output_value"], None

0 commit comments

Comments
 (0)