Skip to content

Commit a4186d9

Browse files
committed
feat: entropy coding from learned dist. (from Gaussian model)
1 parent ceb404b commit a4186d9

File tree

9 files changed

+1045
-18
lines changed

9 files changed

+1045
-18
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,6 @@ Readme_GSCodec.md
133133

134134
figs
135135
stats
136+
temp
136137

137138
!examples/benchmarks/compression/results/
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# ----------------- Training Setting-------------- #
2+
SCENE_DIR="data/tandt"
3+
# eval both scenes (train, truck) for benchmarking
4+
SCENE_LIST="train truck" # truck
5+
# SCENE_LIST="garden bicycle stump bonsai counter kitchen room treehill flowers"
6+
7+
# # 0.36M GSs
8+
# RESULT_DIR="results/benchmark_tt_mcmc_0_36M_png_compression"
9+
# CAP_MAX=360000
10+
11+
# # 0.49M GSs
12+
# RESULT_DIR="results/benchmark_tt_mcmc_tt_0_49M_png_compression"
13+
# CAP_MAX=490000
14+
15+
# 1M GSs
16+
RESULT_DIR="results/Ours_TT_hash_grid"
17+
CAP_MAX=1000000
18+
19+
# # 4M GSs
20+
# RESULT_DIR="results/benchmark_tt_mcmc_4M_png_compression"
21+
# CAP_MAX=4000000
22+
23+
RD_LAMBDA=0.01
24+
25+
# ----------------- Training Setting-------------- #
26+
27+
# ----------------- Args ------------------------- #
28+
29+
if [ ! -z "$1" ]; then
30+
RD_LAMBDA="$1"
31+
RESULT_DIR="results/Ours_TT_rd_lambda_${RD_LAMBDA}"
32+
fi
33+
34+
# ----------------- Args ------------------------- #
35+
36+
# ----------------- Main Job --------------------- #
37+
run_single_scene() {
    # Train one scene on a single GPU, then evaluate the final checkpoint.
    #   $1: GPU id, exported as CUDA_VISIBLE_DEVICES for both python runs
    #   $2: scene name (sub-directory of $SCENE_DIR)
    # Reads the globals SCENE_DIR, RESULT_DIR, CAP_MAX and RD_LAMBDA.
    # Returns non-zero if training fails, so evaluation never runs against a
    # missing or stale checkpoint.
    local GPU_ID="$1"
    local SCENE="$2"

    echo "Running $SCENE on GPU: $GPU_ID"

    # train without eval
    CUDA_VISIBLE_DEVICES="$GPU_ID" python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor 1 \
        --strategy.cap-max "$CAP_MAX" \
        --data_dir "$SCENE_DIR/$SCENE/" \
        --result_dir "$RESULT_DIR/$SCENE/" \
        --compression_sim \
        --entropy_model_opt --entropy_model_type gaussian_model \
        --rd_lambda "$RD_LAMBDA" \
        --shN_ada_mask_opt \
        --compression entropy_coding \
        || { echo "Training failed for $SCENE, skipping eval" >&2; return 1; }

    # eval: use vgg for lpips to align with other benchmarks
    CUDA_VISIBLE_DEVICES="$GPU_ID" python simple_trainer.py mcmc --disable_viewer --data_factor 1 \
        --strategy.cap-max "$CAP_MAX" \
        --data_dir "$SCENE_DIR/$SCENE/" \
        --result_dir "$RESULT_DIR/$SCENE/" \
        --lpips_net vgg \
        --compression entropy_coding --entropy_model_type gaussian_model \
        --ckpt "$RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt"
}
65+
# ----------------- Main Job --------------------- #
66+
67+
68+
69+
# ----------------- Experiment Loop -------------- #
# Launch one background job per scene. GPUs are assigned round-robin so the
# index can never run past the end of GPU_LIST; if there are more scenes than
# GPUs, several jobs will share a device concurrently.
GPU_LIST=(5 7)
GPU_COUNT=${#GPU_LIST[@]}

SCENE_IDX=0
for SCENE in $SCENE_LIST; do
    run_single_scene "${GPU_LIST[$((SCENE_IDX % GPU_COUNT))]}" "$SCENE" &
    SCENE_IDX=$((SCENE_IDX + 1))
done
# ----------------- Experiment Loop -------------- #

# Wait for finishing the jobs across all scenes
wait
echo "All scenes finished."
89+
90+
# Zip the compressed files and summarize the stats
91+
if command -v zip &> /dev/null
92+
then
93+
echo "Zipping results"
94+
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST
95+
else
96+
echo "zip command not found, skipping zipping"
97+
fi
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
SCENE_DIR="data/tandt"
2+
# eval all 2 scenes for benchmarking
3+
SCENE_LIST="train truck" # truck
4+
5+
# # 0.36M GSs
6+
# RESULT_DIR="results/benchmark_tt_mcmc_0_36M_png_compression"
7+
# CAP_MAX=360000
8+
9+
# # 0.49M GSs
10+
# RESULT_DIR="results/benchmark_tt_mcmc_tt_0_49M_png_compression"
11+
# CAP_MAX=490000
12+
13+
# 1M GSs
14+
RESULT_DIR="results/Ours_TT_rd_lambda_0.002_qualitative"
15+
CAP_MAX=1000000
16+
17+
# # 4M GSs
18+
# RESULT_DIR="results/benchmark_tt_mcmc_4M_png_compression"
19+
# CAP_MAX=4000000
20+
21+
for SCENE in $SCENE_LIST;
22+
do
23+
echo "Running $SCENE"
24+
25+
# train without eval
26+
# CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor 1 \
27+
# --strategy.cap-max $CAP_MAX \
28+
# --data_dir $SCENE_DIR/$SCENE/ \
29+
# --result_dir $RESULT_DIR/$SCENE/
30+
31+
# eval: use vgg for lpips to align with other benchmarks
32+
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor 1 \
33+
--strategy.cap-max $CAP_MAX \
34+
--data_dir $SCENE_DIR/$SCENE/ \
35+
--result_dir $RESULT_DIR/$SCENE/ \
36+
--lpips_net vgg \
37+
--compression png \
38+
--ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt
39+
done
40+
41+
# Zip the compressed files and summarize the stats
42+
if command -v zip &> /dev/null
43+
then
44+
echo "Zipping results"
45+
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST
46+
else
47+
echo "zip command not found, skipping zipping"
48+
fi

examples/simple_trainer.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
3131
from typing_extensions import Literal, assert_never
3232
from gsplat import strategy
33+
from gsplat.compression.entropy_coding_compression import EntropyCodingCompression
3334
from gsplat.compression_simulation import simulation
3435
from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed
3536
from lib_bilagrid import (
@@ -44,6 +45,7 @@
4445
from gsplat.rendering import rasterization
4546
from gsplat.strategy import DefaultStrategy, MCMCStrategy
4647
from gsplat.compression_simulation import CompressionSimulation
48+
from gsplat.compression_simulation.entropy_model import Entropy_factorized_optimized_refactor, Entropy_gaussian
4749

4850
class ProfilerConfig:
4951
def __init__(self):
@@ -92,7 +94,7 @@ class Config:
9294
# Path to the .pt files. If provide, it will skip training and run evaluation only.
9395
ckpt: Optional[List[str]] = None
9496
# Name of compression strategy to use
95-
compression: Optional[Literal["png"]] = None
97+
compression: Optional[Literal["png", "entropy_coding"]] = None
9698

9799
# Enable profiler
98100
profiler_enabled: bool = False
@@ -118,10 +120,10 @@ class Config:
118120
"shN": 10_000})
119121
# gaussian model:
120122
# entropy_steps: Dict[str, int] = field(default_factory=lambda: {"means": -1,
121-
# "quats": -1,
123+
# "quats": 10_000,
122124
# "scales": 10_000,
123-
# "opacities": -1,
124-
# "sh0": -1,
125+
# "opacities": 10_000,
126+
# "sh0": 20_000,
125127
# "shN": -1})
126128

127129
# Enable shN adaptive mask
@@ -420,6 +422,8 @@ def __init__(
420422
if cfg.compression is not None:
421423
if cfg.compression == "png":
422424
self.compression_method = PngCompression()
425+
elif cfg.compression == "entropy_coding":
426+
self.compression_method = EntropyCodingCompression()
423427
else:
424428
raise ValueError(f"Unknown compression strategy: {cfg.compression}")
425429

@@ -897,6 +901,11 @@ def train(self):
897901

898902
if cfg.shN_ada_mask_opt and step > cfg.ada_mask_steps:
899903
data["shN_ada_mask"] = shN_ada_mask
904+
905+
if cfg.compression_sim and cfg.entropy_model_opt and cfg.compression == "entropy_coding":
906+
for name, entropy_model in self.compression_sim_method.entropy_models.items():
907+
if entropy_model is not None:
908+
data[name+"_entropy_model"] = entropy_model.state_dict()
900909

901910
torch.save(
902911
data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
@@ -1032,8 +1041,9 @@ def eval(self, step: int, stage: str = "val"):
10321041
canvas_list = [pixels, colors]
10331042

10341043
if world_rank == 0:
1035-
# write images
1036-
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
1044+
# write images
1045+
# canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() # side by side
1046+
canvas = canvas_list[1].squeeze(0).cpu().numpy() # single image
10371047
canvas = (canvas * 255).astype(np.uint8)
10381048
imageio.imwrite(
10391049
f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
@@ -1074,16 +1084,18 @@ def eval(self, step: int, stage: str = "val"):
10741084
self.writer.flush()
10751085

10761086
@torch.no_grad()
1077-
def render_traj(self, step: int):
1087+
def render_traj(self, step: int, stage: str = "val"):
10781088
"""Entry for trajectory rendering."""
10791089
print("Running trajectory rendering...")
10801090
cfg = self.cfg
10811091
device = self.device
10821092

1083-
camtoworlds_all = self.parser.camtoworlds[5:-5]
1093+
num_imgs = len(self.parser.camtoworlds)
1094+
1095+
camtoworlds_all = self.parser.camtoworlds[: num_imgs//2]
10841096
if cfg.render_traj_path == "interp":
10851097
camtoworlds_all = generate_interpolated_path(
1086-
camtoworlds_all, 1
1098+
camtoworlds_all, 6 #1
10871099
) # [N, 3, 4]
10881100
elif cfg.render_traj_path == "ellipse":
10891101
height = camtoworlds_all[:, 2, 3].mean()
@@ -1118,7 +1130,7 @@ def render_traj(self, step: int):
11181130
# save to video
11191131
video_dir = f"{cfg.result_dir}/videos"
11201132
os.makedirs(video_dir, exist_ok=True)
1121-
writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
1133+
writer = imageio.get_writer(f"{video_dir}/{stage}_traj_{step}.mp4", fps=30)
11221134
for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
11231135
camtoworlds = camtoworlds_all[i : i + 1]
11241136
Ks = K[None]
@@ -1139,11 +1151,12 @@ def render_traj(self, step: int):
11391151
canvas_list = [colors, depths.repeat(1, 1, 1, 3)]
11401152

11411153
# write images
1142-
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
1154+
# canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
1155+
canvas = canvas_list[0].squeeze(0).cpu().numpy()
11431156
canvas = (canvas * 255).astype(np.uint8)
11441157
writer.append_data(canvas)
11451158
writer.close()
1146-
print(f"Video saved to {video_dir}/traj_{step}.mp4")
1159+
print(f"Video saved to {video_dir}/{stage}_traj_{step}.mp4")
11471160

11481161
@torch.no_grad()
11491162
def run_compression(self, step: int):
@@ -1156,8 +1169,12 @@ def run_compression(self, step: int):
11561169
# import pdb; pdb.set_trace()
11571170
self.run_param_distribution_vis(self.splats, save_dir=f"{cfg.result_dir}/visualization/raw")
11581171

1159-
self.compression_method.compress(compress_dir, self.splats)
1160-
# self.run_param_distribution_vis(self.splats, save_dir=f"{cfg.result_dir}/visualization/log_transform")
1172+
if isinstance(self.compression_method, PngCompression):
1173+
self.compression_method.compress(compress_dir, self.splats)
1174+
elif isinstance(self.compression_method, EntropyCodingCompression):
1175+
self.compression_method.compress(compress_dir, self.splats, self.entropy_models)
1176+
else:
1177+
raise NotImplementedError(f"The compression method is not implemented yet.")
11611178

11621179
# evaluate compression
11631180
splats_c = self.compression_method.decompress(compress_dir)
@@ -1167,6 +1184,7 @@ def run_compression(self, step: int):
11671184
for k in splats_c.keys():
11681185
self.splats[k].data = splats_c[k].to(self.device)
11691186
self.eval(step=step, stage="compress")
1187+
self.render_traj(step=step, stage="compress")
11701188

11711189
@torch.no_grad()
11721190
def run_param_distribution_vis(self, param_dict: Dict[str, Tensor], save_dir: str):
@@ -1199,6 +1217,21 @@ def run_param_distribution_vis(self, param_dict: Dict[str, Tensor], save_dir: st
11991217
plt.close()
12001218

12011219
print(f"Histograms saved in '{save_dir}' directory.")
1220+
1221+
def load_entropy_model_from_ckpt(self, ckpt: Dict, entropy_model_type: str):
    """Rebuild the per-attribute entropy models stored in a checkpoint.

    Every top-level checkpoint key that ends in ``_entropy_model`` is treated
    as the state dict of the entropy model for the splat attribute named by
    the key's prefix (e.g. ``"scales_entropy_model"`` -> ``"scales"``). The
    restored models are stored in ``self.entropy_models``, keyed by attribute
    name; the dictionary is reset on every call.

    Args:
        ckpt: Loaded checkpoint. For each ``<attr>_entropy_model`` key,
            ``ckpt["splats"][attr]`` must exist; the size of its last
            dimension is used as the model's channel count.
        entropy_model_type: ``"gaussian_model"`` or ``"factorized_model"``.

    Raises:
        NotImplementedError: If an entropy-model entry is found and
            ``entropy_model_type`` is ``"factorized_model"`` (loading is not
            implemented for it yet).
        ValueError: If an entropy-model entry is found and
            ``entropy_model_type`` is not a known type.
    """
    suffix = "_entropy_model"
    self.entropy_models = {}
    for name, state_dict in ckpt.items():
        # Match on the suffix only: the prefix slice below is meaningless
        # when the marker appears in the middle of a key.
        if not name.endswith(suffix):
            continue
        attr_name = name[: len(name) - len(suffix)]
        num_ch = ckpt["splats"][attr_name].shape[-1]

        if entropy_model_type == "gaussian_model":
            entropy_model = Entropy_gaussian(channel=num_ch)
        elif entropy_model_type == "factorized_model":
            # Fail loudly instead of falling through with `entropy_model`
            # unbound (NameError) or still pointing at the previous
            # attribute's model.
            raise NotImplementedError(
                "Loading a factorized entropy model from a checkpoint is "
                "not implemented yet."
            )
        else:
            raise ValueError(
                f"Unknown entropy model type: {entropy_model_type}"
            )

        entropy_model.load_state_dict(state_dict)
        self.entropy_models[attr_name] = entropy_model
12021235

12031236
@torch.no_grad()
12041237
def _viewer_render_fn(
@@ -1242,6 +1275,8 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
12421275
runner.eval(step=step)
12431276
runner.render_traj(step=step)
12441277
if cfg.compression is not None:
1278+
if cfg.compression == "entropy_coding":
1279+
runner.load_entropy_model_from_ckpt(ckpts[0], cfg.entropy_model_type)
12451280
runner.run_compression(step=step)
12461281
else:
12471282
runner.train()

gsplat/compression/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .png_compression import PngCompression
2+
from .entropy_coding_compression import EntropyCodingCompression

0 commit comments

Comments
 (0)