Skip to content

Commit 91a3300

Browse files
committed
refactor: implement adaptive masking strategies in compression simulation
1 parent 02fd977 commit 91a3300

File tree

12 files changed

+940
-46
lines changed

12 files changed

+940
-46
lines changed

dev_docs/config_usage.md

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,57 @@ python examples/simple_trainer.py default --config logs/run_001.yaml
3535

3636
> 提示:策略(`strategy`)字段会以 `{type: DefaultStrategy, params: {...}}` 形式写入 YAML,方便直接调整或复用。
3737
38+
39+
## Adaptive Mask 重构设计(进行中)
40+
41+
- **目标**:将自适应 SHN mask 重构为独立子模块,和可微分量化、熵约束保持同等抽象层级;Trainer 与 orchestrator 仅通过统一接口消费 mask 输出/损失/指标。
42+
- **配置结构**
43+
```python
44+
@dataclass
45+
class LearnableMaskSettings:
46+
start_temp: float = 5.0
47+
end_temp: float = 0.1
48+
total_iters: int = 30_000
49+
target_sparsity: float = 0.2
50+
lr: float = 1e-2
51+
52+
@dataclass
53+
class GradientMaskSettings:
54+
grad_threshold: float = 2e-3 # 小于该阈值的梯度会被清零
55+
56+
@dataclass
57+
class MaskConfig:
58+
enabled: bool = False
59+
strategy: Optional[str] = "learnable"
60+
start_step: int = 10_000
61+
learnable: LearnableMaskSettings = field(default_factory=LearnableMaskSettings)
62+
gradient: GradientMaskSettings = field(default_factory=GradientMaskSettings)
63+
```
64+
YAML 示例:
65+
```yaml
66+
compression_sim_cfg:
67+
mask:
68+
enabled: true
69+
strategy: gradient
70+
start_step: 8000
71+
gradient:
72+
grad_threshold: 0.003
73+
```
74+
- **模块划分**:新增 `gsplat/compression_simulation/mask.py`,定义 `MaskResult`、`AdaptiveMaskBase` 接口;`AdaptiveMaskFactory` 根据策略返回 `LearnableAdaptiveMask`(封装 `AnnealingMask`)、`GradientAdaptiveMask`(梯度阈值裁剪)或 `NullAdaptiveMask`。
75+
- **Gradient 策略**:
76+
- `maybe_update` 统计 SHN 非零比例,供日志与阈值计算。
77+
- `apply` 在 `step > start_step` 时注册一次 `tensor.register_hook`,返回 `MaskResult(value=tensor, metrics={"mask_ratio": ..., "mask_grad_threshold": ...})`。
78+
- Hook 逻辑:
79+
```python
80+
def grad_hook(grad):
81+
shn = tensor.detach()
82+
zero_mask = (shn == 0).all(dim=-1).all(dim=-1)
83+
grad_norm = grad.flatten(2).norm(p=2, dim=-1)
84+
mask = ~(zero_mask & (grad_norm < cfg.gradient.grad_threshold))
85+
while mask.ndim < grad.ndim:
86+
mask = mask.unsqueeze(-1)
87+
return grad * mask
88+
```
89+
- **Orchestrator 接入**:`CompressionSimulation.run()` 中调用 `mask.maybe_update(step, splats)`;对 `shN` 应用 `mask.apply` 并汇总 `loss_terms` 与 `metrics`;`step_optimizers`/`state_dict` 同步 mask 状态。
90+
- **Trainer/日志**:训练循环只消费 `SimulationResult.loss_terms.get("mask")` 与 `metrics["mask_ratio"]`;导出 PLY 时通过 mask 模块暴露的 `get_binary_mask()` 获取最终掩码。
91+
- **测试计划**:新增 `tests/test_adaptive_mask.py` 覆盖 learnable(温度调度、loss、优化器)、gradient(阈值裁剪、指标)与 null(passthrough)。

dev_docs/post-training_comp.md

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Post-Training Compression 重构设想
2+
3+
## 目标
4+
- 支持统一的后处理压缩入口,可消费训练得到的 splats(来源为 PLY 或 checkpoint)。
5+
- 提供 encode / decode 两个核心函数:
6+
- `encode()`:从输入载入 splats,执行 pruning、quantization、3D→2D 映射、视频/图像编码,输出码流或压缩文件集合。
7+
- `decode()`:从码流恢复 splats,执行视频/图像解码、反量化等操作,可输出内存结构或保存为 PLY。
8+
- 允许通过配置选择 codec 以及预处理策略,便于扩展新变种。
9+
10+
## 拟议架构
11+
### 1. 顶层 Orchestrator(暂名 `PostTrainingCompressor`
12+
- 初始化参数:
13+
- `input_spec`: 指定输入类型(`ply``ckpt`)及路径。
14+
- `codec_config`: 指定编码方式(PNG / Entropy / HEVC / SeqYUV 等)及对应参数。
15+
- `preprocess_config`: 可选,控制 pruning、排序、属性变换等。
16+
- `quant_config`: 可选,控制按属性的量化位宽、截断范围等。
17+
- 方法:
18+
- `encode()`:驱动全流程;返回压缩输出路径、元信息。
19+
- `decode(compressed_dir)`:读取码流、还原 splats;可返回字典或写入 PLY。
20+
21+
### 2. 数据加载层
22+
- `load_ckpt(path)`:从 checkpoint 中提取 `splats`(means/scales/quats/...)。
23+
- `load_ply_sequence(path | list)`:加载单帧或序列 PLY,返回统一的张量字典。
24+
- 两者都转换成标准的 `Dict[str, Tensor]` 供后续模块使用。
25+
26+
### 3. 编码流水线模块
27+
- `PruningStage`:离群点过滤、mask 过滤等,可配置开关。
28+
- `QuantizationStage`:基于属性设置 bitwidth、clamp range,复用现有 `_compress_*` 逻辑。
29+
- `MappingStage`:排序/映射策略(PLAS、morton、无排序);负责 3D→2D 重排。
30+
- `CodecStage`:调度 `gsplat/compression` 中的具体 codec 类。
31+
- 元信息统一写入 `meta.json`(沿用现状)。
32+
33+
### 4. 解码流水线模块
34+
- `CodecStage.decode`:调用 codec 的 `decompress`,获得属性张量。
35+
- `DequantizationStage`:执行逆变换、逆量化;可复用 `inverse_log_transform` 等函数。
36+
- 最终输出:
37+
- 内存中的 `Dict[str, Tensor]`,或
38+
- 调用 `save_ply` 写入磁盘。
39+
40+
### 5. 配置结构
41+
使用 dataclass 组织配置,便于 YAML/CLI:
42+
```python
43+
@dataclass
44+
class PTCompressionConfig:
45+
input_type: Literal["ply", "ckpt"]
46+
codec: Literal["png", "entropy", "hevc", "seq_hevc", "seq_yuv"]
47+
codec_params: Dict[str, Any] = field(default_factory=dict)
48+
preprocess: PreprocessConfig = field(default_factory=PreprocessConfig)
49+
quant: QuantConfig = field(default_factory=QuantConfig)
50+
```
51+
其中 `PreprocessConfig``QuantConfig` 再细分字段(是否过滤、排序策略、bitwidth 等)。
52+
53+
### 6. CLI/脚本整合
54+
- 重写/新增脚本(如 `examples/benchmarks/post_train_compress.sh`)调用 `PostTrainingCompressor`
55+
- YAML/CLI 用同一套配置,避免脚本内硬编码。
56+
- 脚本流程:解析配置 → `encode()``decode()` → 统计评估 → 输出路径。
57+
58+
## 顾及到的补充需求
59+
- 模块内部独立提供 PLY / CKPT 的读写接口(与 encode/decode 解耦)。
60+
- 支持读取单帧或序列 splats,保留 `is_sequence``num_frames` 标识。
61+
- `encode()` 前和 `decode()` 后输出参数分布直方图(matplotlib),但不写入 meta 数据。
62+
- 实现时参考现有 `gsplat/compression` 的操作顺序,避免与现有流程背离。
63+
64+
## `PostTrainingComp` 草图
65+
```python
66+
@dataclass
67+
class PostTrainingComp:
68+
input_spec: InputSpec
69+
codec_config: CodecConfig
70+
preprocess_cfg: PreprocessConfig
71+
quant_cfg: QuantConfig
72+
output_dir: Path
73+
74+
splats: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = field(init=False)
75+
is_sequence: bool = field(init=False)
76+
codec: BaseCodec = field(init=False)
77+
metadata: Dict[str, Any] = field(default_factory=dict)
78+
79+
def load_inputs(self) -> None:
80+
if self.input_spec.type == "ply":
81+
self.splats = load_ply(self.input_spec.path, as_sequence=self.input_spec.as_sequence)
82+
elif self.input_spec.type == "ckpt":
83+
self.splats = load_ckpt(self.input_spec.path)
84+
self.is_sequence = isinstance(self.splats, list)
85+
self.metadata["num_frames"] = len(self.splats) if self.is_sequence else 1
86+
87+
def encode(self) -> CompressionResult:
88+
self.load_inputs()
89+
self._plot_stats(self.splats, stage="before_encode")
90+
payload = self._run_encode_pipeline(self.splats)
91+
self._save_payload(payload)
92+
return CompressionResult(payload_path=..., metadata=self.metadata)
93+
94+
def decode(self, payload_path: Path) -> DecodeResult:
95+
payload = self._load_payload(payload_path)
96+
decoded = self._run_decode_pipeline(payload)
97+
self._plot_stats(decoded, stage="after_decode")
98+
self._write_outputs(decoded)
99+
return DecodeResult(splats=decoded, saved_paths=...)
100+
101+
def _run_encode_pipeline(self, splats):
102+
pruned = run_pruning(splats, self.preprocess_cfg)
103+
quantized = run_quantization(pruned, self.quant_cfg)
104+
mapped, mapping_ctx = run_mapping(quantized, self.preprocess_cfg.mapping)
105+
encoded = self.codec.encode(mapped, context=mapping_ctx)
106+
return {"encoded": encoded, "mapping_ctx": mapping_ctx}
107+
108+
def _run_decode_pipeline(self, payload):
109+
decoded = self.codec.decode(payload["encoded"], context=payload.get("mapping_ctx"))
110+
unmapped = run_inverse_mapping(decoded, payload.get("mapping_ctx"))
111+
dequant = run_dequantization(unmapped, self.quant_cfg)
112+
restored = run_postprocess(dequant, self.preprocess_cfg)
113+
return restored
114+
115+
def _plot_stats(self, splats, stage: str) -> None:
116+
# 遍历属性画直方图,保存到 output_dir/stage_* 下
117+
```
118+
119+
- `InputSpec``CodecConfig``PreprocessConfig``QuantConfig` 等 dataclass 可进一步细化。
120+
- `CompressionResult` / `DecodeResult` 用于统一返回路径、元数据。
121+
- PLY/CKPT 的读写实现为独立的 util 函数。
122+
- `BaseCodec` 为各 codec 的抽象基类,具体实现参考现有 PNG/Entropy/HEVC 等类。
123+
124+
## 未决问题
125+
1. 是否需要一次性处理多个场景/帧并输出统一码流?
126+
2. 统计文件命名、保存位置是否需要可配置?
127+
3. 与现有脚本结合时,是否要提供默认 YAML 模板。
128+
129+
## 设计补充(2024-xx-xx)
130+
- `Quantization` / `Dequantization` 阶段为必选流程,需始终执行。
131+
- `Pruning``Mapping``Codec` 阶段可以按需求启用或跳过,可通过配置显式控制。
132+
- `PostTrainingComp` 需提供 `save_decoded_splats(decoded, destination)` 助手,用于在 `decode()` 完成后将还原的 splats 持久化到 PLY / CKPT 等目标格式,避免重复实现写盘逻辑。
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# ----------------- Training Setting-------------- #
2+
SCENE_DIR="data/tandt"
3+
# eval all 9 scenes for benchmarking
4+
SCENE_LIST="train truck" # train truck
5+
# SCENE_LIST="garden bicycle stump bonsai counter kitchen room treehill flowers"
6+
7+
# # 0.36M GSs
8+
# RESULT_DIR="results/benchmark_tt_mcmc_0_36M_png_compression"
9+
# CAP_MAX=360000
10+
11+
# # 0.49M GSs
12+
# RESULT_DIR="results/benchmark_tt_mcmc_tt_0_49M_png_compression"
13+
# CAP_MAX=490000
14+
15+
# 1M GSs
16+
RESULT_DIR="results/new_cfg_tt"
17+
CAP_MAX=1000000
18+
19+
# # 4M GSs
20+
# RESULT_DIR="results/benchmark_tt_mcmc_4M_png_compression"
21+
# CAP_MAX=4000000
22+
23+
RD_LAMBDA=0.01
24+
25+
# ----------------- Training Setting-------------- #
26+
27+
# ----------------- Args ------------------------- #
28+
29+
if [ ! -z "$1" ]; then
30+
RD_LAMBDA="$1"
31+
RESULT_DIR="results/new_cfg_tt_rd_lambda_${RD_LAMBDA}"
32+
fi
33+
34+
# ----------------- Args ------------------------- #
35+
36+
# ----------------- Main Job --------------------- #
37+
run_single_scene() {
38+
local GPU_ID=$1
39+
local SCENE=$2
40+
41+
echo "Running $SCENE on GPU: $GPU_ID"
42+
43+
# train without eval
44+
CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor 1 \
45+
--strategy.cap-max $CAP_MAX \
46+
--data_dir $SCENE_DIR/$SCENE/ \
47+
--result_dir $RESULT_DIR/$SCENE/ \
48+
--compression_sim \
49+
--entropy_model_opt \
50+
--rd_lambda $RD_LAMBDA \
51+
--shN_ada_mask_opt \
52+
--compression png
53+
54+
55+
# eval: use vgg for lpips to align with other benchmarks
56+
CUDA_VISIBLE_DEVICES=$GPU_ID python simple_trainer.py mcmc --disable_viewer --data_factor 1 \
57+
--strategy.cap-max $CAP_MAX \
58+
--data_dir $SCENE_DIR/$SCENE/ \
59+
--result_dir $RESULT_DIR/$SCENE/ \
60+
--lpips_net vgg \
61+
--ckpt $RESULT_DIR/$SCENE/ckpts/ckpt_29999_rank0.pt \
62+
--compression png
63+
64+
}
65+
# ----------------- Main Job --------------------- #
66+
67+
68+
69+
# ----------------- Experiment Loop -------------- #
70+
GPU_LIST=(6 7)
71+
GPU_COUNT=${#GPU_LIST[@]}
72+
73+
SCENE_IDX=-1
74+
75+
for SCENE in $SCENE_LIST;
76+
do
77+
SCENE_IDX=$((SCENE_IDX + 1))
78+
{
79+
run_single_scene ${GPU_LIST[$SCENE_IDX]} $SCENE
80+
} #&
81+
82+
done
83+
84+
# ----------------- Experiment Loop -------------- #
85+
86+
# Wait for finishing the jobs across all scenes
87+
wait
88+
echo "All scenes finished."
89+
90+
# Zip the compressed files and summarize the stats
91+
if command -v zip &> /dev/null
92+
then
93+
echo "Zipping results"
94+
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST
95+
else
96+
echo "zip command not found, skipping zipping"
97+
fi

0 commit comments

Comments
 (0)