Skip to content

Commit bb04f83

Browse files
committed
feat: update config parser, more flexible cli command for post-training compression, update gaussian_entropy_model
1 parent 8d3c54b commit bb04f83

File tree

7 files changed

+334
-81
lines changed

7 files changed

+334
-81
lines changed

examples/benchmarks/compression/final_exp/mcmc_tt_sim.sh

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ run_single_scene() {
4141
echo "Running $SCENE on GPU: $GPU_ID"
4242

4343
# train without eval
44-
CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor 1 \
45-
--strategy.cap-max $CAP_MAX \
46-
--data_dir $SCENE_DIR/$SCENE/ \
47-
--result_dir $RESULT_DIR/$SCENE/ \
48-
--compression_sim \
49-
--entropy_model_opt \
50-
--rd_lambda $RD_LAMBDA \
51-
--shN_ada_mask_opt \
52-
--compression png
44+
# CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor 1 \
45+
# --strategy.cap-max $CAP_MAX \
46+
# --data_dir $SCENE_DIR/$SCENE/ \
47+
# --result_dir $RESULT_DIR/$SCENE/ \
48+
# --compression_sim \
49+
# --entropy_model_opt \
50+
# --rd_lambda $RD_LAMBDA \
51+
# --shN_ada_mask_opt \
52+
# --compression png
5353

5454

5555
# eval: use vgg for lpips to align with other benchmarks

examples/benchmarks/compression/final_exp/mcmc_tt_sim_hash_grid.sh

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ run_single_scene() {
4141
echo "Running $SCENE on GPU: $GPU_ID"
4242

4343
# train without eval
44-
# CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor 1 \
45-
# --strategy.cap-max $CAP_MAX \
46-
# --data_dir $SCENE_DIR/$SCENE/ \
47-
# --result_dir $RESULT_DIR/$SCENE/ \
48-
# --compression_sim \
49-
# --entropy_model_opt --entropy_model_type gaussian_model \
50-
# --rd_lambda $RD_LAMBDA \
51-
# --shN_ada_mask_opt \
52-
# --compression entropy_coding
44+
CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer.py mcmc --eval_steps 10000 20000 30000 --disable_viewer --data_factor 1 \
45+
--strategy.cap-max $CAP_MAX \
46+
--data_dir $SCENE_DIR/$SCENE/ \
47+
--result_dir $RESULT_DIR/$SCENE/ \
48+
--compression_sim \
49+
--entropy_model_opt --entropy_model_type gaussian_model \
50+
--rd_lambda $RD_LAMBDA \
51+
--shN_ada_mask_opt \
52+
--compression entropy_coding
5353

5454

5555
# eval: use vgg for lpips to align with other benchmarks
@@ -59,15 +59,19 @@ run_single_scene() {
5959
--result_dir $RESULT_DIR/$SCENE/ \
6060
--lpips_net vgg \
6161
--compression entropy_coding --entropy_model_type gaussian_model \
62-
--ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt
62+
--ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt \
63+
--compression_cfg.attribute_codec_registry.scales.encode _compress_gaussian_ans \
64+
--compression_cfg.attribute_codec_registry.scales.decode _decompress_gaussian_ans \
65+
--compression_cfg.attribute_codec_registry.quats.encode _compress_gaussian_ans \
66+
--compression_cfg.attribute_codec_registry.quats.decode _decompress_gaussian_ans
6367

6468
}
6569
# ----------------- Main Job --------------------- #
6670

6771

6872

6973
# ----------------- Experiment Loop -------------- #
70-
GPU_LIST=(3 4)
74+
GPU_LIST=(0 1)
7175
GPU_COUNT=${#GPU_LIST[@]}
7276

7377
SCENE_IDX=-1
@@ -77,7 +81,7 @@ do
7781
SCENE_IDX=$((SCENE_IDX + 1))
7882
{
7983
run_single_scene ${GPU_LIST[$SCENE_IDX]} $SCENE
80-
} &
84+
} #&
8185

8286
done
8387

examples/simple_trainer.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from contextlib import nullcontext
77
from dataclasses import dataclass, field
88
from collections import defaultdict
9-
from typing import Dict, List, Optional, Tuple, Union, ContextManager
9+
from typing import Dict, List, Optional, Tuple, Union, ContextManager, TypedDict, Any
1010

1111
import imageio
1212
import nerfview
@@ -77,12 +77,65 @@ def _create_schedule(self):
7777
)
7878

7979
def update_schedule(self, **kwargs):
    """Update existing attributes from ``kwargs`` and rebuild the schedule.

    Each keyword whose name matches an attribute already present on
    ``self`` overwrites that attribute; keywords that do not match any
    existing attribute are ignored. ``self.schedule`` is then recomputed
    with ``self._create_schedule()``.
    """
    for key, value in kwargs.items():
        # Only overwrite attributes that already exist; never create new ones.
        if hasattr(self, key):
            setattr(self, key, value)
    self.schedule = self._create_schedule()
8584

85+
@dataclass
class CodecConfig:
    """Pair of codec names (encode and decode) for a single attribute.

    Both values are stored as plain strings; this class does not resolve
    them to callables.
    """

    # Name of the routine used to encode the attribute.
    encode: str
    # Name of the routine used to decode the attribute.
    decode: str
89+
90+
@dataclass
class AttributeCodecs:
    """Codec selection per attribute; every field has its own default pair."""

    means: CodecConfig = field(
        default_factory=lambda: CodecConfig(
            "_compress_png_16bit", "_decompress_png_16bit"
        )
    )
    scales: CodecConfig = field(
        default_factory=lambda: CodecConfig(
            "_compress_factorized_ans", "_decompress_factorized_ans"
        )
    )
    quats: CodecConfig = field(
        default_factory=lambda: CodecConfig(
            "_compress_factorized_ans", "_decompress_factorized_ans"
        )
    )
    opacities: CodecConfig = field(
        default_factory=lambda: CodecConfig("_compress_png", "_decompress_png")
    )
    sh0: CodecConfig = field(
        default_factory=lambda: CodecConfig("_compress_png", "_decompress_png")
    )
    shN: CodecConfig = field(
        default_factory=lambda: CodecConfig(
            "_compress_masked_kmeans", "_decompress_masked_kmeans"
        )
    )

    def to_dict(self) -> Dict[str, Dict[str, str]]:
        """Return ``{attribute: {"encode": name, "decode": name}}``.

        The attribute names are taken from the dataclass fields, in
        declaration order, so the mapping cannot drift from the fields
        declared above.
        """
        # Imported locally so this block does not rely on a file-level
        # import of ``fields``.
        from dataclasses import fields

        return {
            f.name: {
                "encode": getattr(self, f.name).encode,
                "decode": getattr(self, f.name).decode,
            }
            for f in fields(self)
        }
104+
105+
@dataclass
class CompressionConfig:
    """Options for the compression method, exportable as a plain dict."""

    # Use PLAS sort in compression or not
    use_sort: bool = True
    # Verbose or not
    verbose: bool = True
    # QP value for video coding
    qp: Optional[int] = None
    # Number of clusters of VQ for shN compression
    n_clusters: int = 32768
    # Maps attribute names to their codec functions
    attribute_codec_registry: Optional[AttributeCodecs] = field(
        default_factory=AttributeCodecs
    )

    def to_dict(self) -> Dict[str, Any]:
        """Convert this configuration to a dictionary.

        ``use_sort``, ``verbose`` and ``n_clusters`` are always present.
        ``qp`` is included only when it is not None.
        ``attribute_codec_registry`` is included only when it is not None,
        converted with its own ``to_dict`` method.
        """
        result: Dict[str, Any] = {
            "use_sort": self.use_sort,
            "verbose": self.verbose,
            "n_clusters": self.n_clusters,
        }

        # Optional entries are omitted entirely rather than emitted as None.
        if self.qp is not None:
            result["qp"] = self.qp

        if self.attribute_codec_registry is not None:
            result["attribute_codec_registry"] = (
                self.attribute_codec_registry.to_dict()
            )

        return result
138+
86139
@dataclass
87140
class Config:
88141
# Disable viewer
@@ -91,8 +144,12 @@ class Config:
91144
ckpt: Optional[List[str]] = None
92145
# Name of compression strategy to use
93146
compression: Optional[Literal["png", "entropy_coding", "hevc"]] = None
94-
# Quantization parameters when set to hevc
95-
qp: Optional[int] = None
147+
# # Quantization parameters when set to hevc
148+
# qp: Optional[int] = None
149+
# Configuration for compression methods
150+
compression_cfg: CompressionConfig = field(
151+
default_factory=CompressionConfig
152+
)
96153

97154
# Enable profiler
98155
profiler_enabled: bool = False
@@ -549,9 +606,11 @@ def __init__(
549606
if cfg.compression == "png":
550607
self.compression_method = PngCompression()
551608
elif cfg.compression == "entropy_coding":
552-
self.compression_method = EntropyCodingCompression()
609+
compression_cfg = cfg.compression_cfg.to_dict()
610+
self.compression_method = EntropyCodingCompression(**compression_cfg)
553611
elif cfg.compression == "hevc":
554-
self.compression_method = HevcCompression(qp=cfg.qp)
612+
compression_cfg = cfg.compression_cfg.to_dict()
613+
self.compression_method = HevcCompression(**compression_cfg)
555614
else:
556615
raise ValueError(f"Unknown compression strategy: {cfg.compression}")
557616

0 commit comments

Comments
 (0)