|
| 1 | +# Post-Training Compression 重构设想 |
| 2 | + |
| 3 | +## 目标 |
| 4 | +- 支持统一的后处理压缩入口,可消费训练得到的 splats(来源为 PLY 或 checkpoint)。 |
| 5 | +- 提供 encode / decode 两个核心函数: |
| 6 | + - `encode()`:从输入载入 splats,执行 pruning、quantization、3D→2D 映射、视频/图像编码,输出码流或压缩文件集合。 |
| 7 | + - `decode()`:从码流恢复 splats,执行视频/图像解码、反量化等操作,可输出内存结构或保存为 PLY。 |
| 8 | +- 允许通过配置选择 codec 以及预处理策略,便于扩展新变种。 |
| 9 | + |
| 10 | +## 拟议架构 |
| 11 | +### 1. 顶层 Orchestrator(暂名 `PostTrainingCompressor`) |
| 12 | +- 初始化参数: |
| 13 | + - `input_spec`: 指定输入类型(`ply` 或 `ckpt`)及路径。 |
| 14 | + - `codec_config`: 指定编码方式(PNG / Entropy / HEVC / SeqYUV 等)及对应参数。 |
| 15 | + - `preprocess_config`: 可选,控制 pruning、排序、属性变换等。 |
| 16 | + - `quant_config`: 可选,控制按属性的量化位宽、截断范围等。 |
| 17 | +- 方法: |
| 18 | + - `encode()`:驱动全流程;返回压缩输出路径、元信息。 |
| 19 | + - `decode(compressed_dir)`:读取码流、还原 splats;可返回字典或写入 PLY。 |
| 20 | + |
| 21 | +### 2. 数据加载层 |
| 22 | +- `load_ckpt(path)`:从 checkpoint 中提取 `splats`(means/scales/quats/...)。 |
| 23 | +- `load_ply_sequence(path | list)`:加载单帧或序列 PLY,返回统一的张量字典。 |
| 24 | +- 两者都转换成标准的 `Dict[str, Tensor]` 供后续模块使用。 |
| 25 | + |
| 26 | +### 3. 编码流水线模块 |
| 27 | +- `PruningStage`:离群点过滤、mask 过滤等,可配置开关。 |
| 28 | +- `QuantizationStage`:基于属性设置 bitwidth、clamp range,复用现有 `_compress_*` 逻辑。 |
| 29 | +- `MappingStage`:排序/映射策略(PLAS、morton、无排序);负责 3D→2D 重排。 |
| 30 | +- `CodecStage`:调度 `gsplat/compression` 中的具体 codec 类。 |
| 31 | +- 元信息统一写入 `meta.json`(沿用现状)。 |
| 32 | + |
| 33 | +### 4. 解码流水线模块 |
| 34 | +- `CodecStage.decode`:调用 codec 的 `decompress`,获得属性张量。 |
| 35 | +- `DequantizationStage`:执行逆变换、逆量化;可复用 `inverse_log_transform` 等函数。 |
| 36 | +- 最终输出: |
| 37 | + - 内存中的 `Dict[str, Tensor]`,或 |
| 38 | + - 调用 `save_ply` 写入磁盘。 |
| 39 | + |
| 40 | +### 5. 配置结构 |
| 41 | +使用 dataclass 组织配置,便于 YAML/CLI: |
| 42 | +```python |
| 43 | +@dataclass |
| 44 | +class PTCompressionConfig: |
| 45 | + input_type: Literal["ply", "ckpt"] |
| 46 | + codec: Literal["png", "entropy", "hevc", "seq_hevc", "seq_yuv"] |
| 47 | + codec_params: Dict[str, Any] = field(default_factory=dict) |
| 48 | + preprocess: PreprocessConfig = field(default_factory=PreprocessConfig) |
| 49 | + quant: QuantConfig = field(default_factory=QuantConfig) |
| 50 | +``` |
| 51 | +其中 `PreprocessConfig`、`QuantConfig` 再细分字段(是否过滤、排序策略、bitwidth 等)。 |
| 52 | + |
| 53 | +### 6. CLI/脚本整合 |
| 54 | +- 重写/新增脚本(如 `examples/benchmarks/post_train_compress.sh`)调用 `PostTrainingCompressor`。 |
| 55 | +- YAML/CLI 用同一套配置,避免脚本内硬编码。 |
| 56 | +- 脚本流程:解析配置 → `encode()` → `decode()` → 统计评估 → 输出路径。 |
| 57 | + |
| 58 | +## 顾及到的补充需求 |
| 59 | +- 模块内部独立提供 PLY / CKPT 的读写接口(与 encode/decode 解耦)。 |
| 60 | +- 支持读取单帧或序列 splats,保留 `is_sequence` 或 `num_frames` 标识。 |
| 61 | +- `encode()` 前和 `decode()` 后输出参数分布直方图(matplotlib),但不写入 meta 数据。 |
| 62 | +- 实现时参考现有 `gsplat/compression` 的操作顺序,避免与现有流程背离。 |
| 63 | + |
| 64 | +## `PostTrainingComp` 草图 |
| 65 | +```python |
| 66 | +@dataclass |
| 67 | +class PostTrainingComp: |
| 68 | + input_spec: InputSpec |
| 69 | + codec_config: CodecConfig |
| 70 | + preprocess_cfg: PreprocessConfig |
| 71 | + quant_cfg: QuantConfig |
| 72 | + output_dir: Path |
| 73 | + |
| 74 | + splats: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = field(init=False) |
| 75 | + is_sequence: bool = field(init=False) |
| 76 | + codec: BaseCodec = field(init=False) |
| 77 | + metadata: Dict[str, Any] = field(default_factory=dict) |
| 78 | + |
| 79 | + def load_inputs(self) -> None: |
| 80 | + if self.input_spec.type == "ply": |
| 81 | + self.splats = load_ply(self.input_spec.path, as_sequence=self.input_spec.as_sequence) |
| 82 | + elif self.input_spec.type == "ckpt": |
| 83 | + self.splats = load_ckpt(self.input_spec.path) |
| 84 | + self.is_sequence = isinstance(self.splats, list) |
| 85 | + self.metadata["num_frames"] = len(self.splats) if self.is_sequence else 1 |
| 86 | + |
| 87 | + def encode(self) -> CompressionResult: |
| 88 | + self.load_inputs() |
| 89 | + self._plot_stats(self.splats, stage="before_encode") |
| 90 | + payload = self._run_encode_pipeline(self.splats) |
| 91 | + self._save_payload(payload) |
| 92 | + return CompressionResult(payload_path=..., metadata=self.metadata) |
| 93 | + |
| 94 | + def decode(self, payload_path: Path) -> DecodeResult: |
| 95 | + payload = self._load_payload(payload_path) |
| 96 | + decoded = self._run_decode_pipeline(payload) |
| 97 | + self._plot_stats(decoded, stage="after_decode") |
| 98 | + self._write_outputs(decoded) |
| 99 | + return DecodeResult(splats=decoded, saved_paths=...) |
| 100 | + |
| 101 | + def _run_encode_pipeline(self, splats): |
| 102 | + pruned = run_pruning(splats, self.preprocess_cfg) |
| 103 | + quantized = run_quantization(pruned, self.quant_cfg) |
| 104 | + mapped, mapping_ctx = run_mapping(quantized, self.preprocess_cfg.mapping) |
| 105 | + encoded = self.codec.encode(mapped, context=mapping_ctx) |
| 106 | + return {"encoded": encoded, "mapping_ctx": mapping_ctx} |
| 107 | + |
| 108 | + def _run_decode_pipeline(self, payload): |
| 109 | + decoded = self.codec.decode(payload["encoded"], context=payload.get("mapping_ctx")) |
| 110 | + unmapped = run_inverse_mapping(decoded, payload.get("mapping_ctx")) |
| 111 | + dequant = run_dequantization(unmapped, self.quant_cfg) |
| 112 | + restored = run_postprocess(dequant, self.preprocess_cfg) |
| 113 | + return restored |
| 114 | + |
| 115 | + def _plot_stats(self, splats, stage: str) -> None: |
| 116 | + # 遍历属性画直方图,保存到 output_dir/stage_* 下 |
| 117 | +``` |
| 118 | + |
| 119 | +- `InputSpec`、`CodecConfig`、`PreprocessConfig`、`QuantConfig` 等 dataclass 可进一步细化。 |
| 120 | +- `CompressionResult` / `DecodeResult` 用于统一返回路径、元数据。 |
| 121 | +- PLY/CKPT 的读写实现为独立的 util 函数。 |
| 122 | +- `BaseCodec` 为各 codec 的抽象基类,具体实现参考现有 PNG/Entropy/HEVC 等类。 |
| 123 | + |
| 124 | +## 未决问题 |
| 125 | +1. 是否需要一次性处理多个场景/帧并输出统一码流? |
| 126 | +2. 统计文件命名、保存位置是否需要可配置? |
| 127 | +3. 与现有脚本结合时,是否要提供默认 YAML 模板。 |
| 128 | + |
| 129 | +## 设计补充(2024-xx-xx) |
| 130 | +- `Quantization` / `Dequantization` 阶段为必选流程,需始终执行。 |
| 131 | +- `Pruning`、`Mapping`、`Codec` 阶段可以按需求启用或跳过,可通过配置显式控制。 |
| 132 | +- `PostTrainingComp` 需提供 `save_decoded_splats(decoded, destination)` 助手,用于在 `decode()` 完成后将还原的 splats 持久化到 PLY / CKPT 等目标格式,避免重复实现写盘逻辑。 |
0 commit comments