|
1 | 1 | import json |
2 | 2 | import os |
3 | | -from dataclasses import dataclass |
| 3 | +from dataclasses import dataclass, field, InitVar |
4 | 4 | import shutil |
5 | | -from typing import Any, Callable, Dict |
| 5 | +from typing import Any, Callable, Dict, Optional |
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 | import torch |
@@ -46,36 +46,55 @@ class HevcCompression: |
46 | 46 |
|
47 | 47 | use_sort: bool = True |
48 | 48 | verbose: bool = True |
49 | | - qp: int = 4 |
50 | 49 | n_clusters: int = 32768 |
| 50 | + qp: int = 4 |
| 51 | + |
| 52 | + attribute_codec_registry: InitVar[Optional[Dict[str, str]]] = None |
| 53 | + |
| 54 | + compress_fn_map: Dict[str, Callable] = field(default_factory=lambda: { |
| 55 | + "means": _compress_png_16bit, |
| 56 | + "scales": _compress_hevc_kbit, |
| 57 | + "quats": _compress_quats_hevc_kbit, |
| 58 | + "opacities": _compress_hevc_kbit, |
| 59 | + "sh0": _compress_hevc_kbit, |
| 60 | + "shN": _compress_masked_kmeans, |
| 61 | + }) |
| 62 | + decompress_fn_map: Dict[str, Callable] = field(default_factory=lambda: { |
| 63 | + "means": _decompress_png_16bit, |
| 64 | + "scales": _decompress_hevc_kbit, |
| 65 | + "quats": _decompress_quats_hevc_kbit, |
| 66 | + "opacities": _decompress_hevc_kbit, |
| 67 | + "sh0": _decompress_hevc_kbit, |
| 68 | + "shN": _decompress_masked_kmeans, |
| 69 | + }) |
51 | 70 |
|
52 | 71 | def _get_compress_fn(self, param_name: str) -> Callable: |
53 | | - compress_fn_map = { |
54 | | - "means": _compress_png_16bit, |
55 | | - "scales": _compress_hevc_kbit, |
56 | | - "quats": _compress_quats_hevc_kbit, |
57 | | - "opacities": _compress_hevc_kbit, |
58 | | - "sh0": _compress_hevc_kbit, |
59 | | - # "shN": _compress_kmeans, |
60 | | - "shN": _compress_masked_kmeans, |
61 | | - } |
62 | | - if param_name in compress_fn_map: |
63 | | - return compress_fn_map[param_name] |
| 72 | + # compress_fn_map = { |
| 73 | + # "means": _compress_png_16bit, |
| 74 | + # "scales": _compress_hevc_kbit, |
| 75 | + # "quats": _compress_quats_hevc_kbit, |
| 76 | + # "opacities": _compress_hevc_kbit, |
| 77 | + # "sh0": _compress_hevc_kbit, |
| 78 | + # # "shN": _compress_kmeans, |
| 79 | + # "shN": _compress_masked_kmeans, |
| 80 | + # } |
| 81 | + if param_name in self.compress_fn_map: |
| 82 | + return self.compress_fn_map[param_name] |
64 | 83 | else: |
65 | 84 | return _compress_npz |
66 | 85 |
|
67 | 86 | def _get_decompress_fn(self, param_name: str) -> Callable: |
68 | | - decompress_fn_map = { |
69 | | - "means": _decompress_png_16bit, |
70 | | - "scales": _decompress_hevc_kbit, |
71 | | - "quats": _decompress_quats_hevc_kbit, |
72 | | - "opacities": _decompress_hevc_kbit, |
73 | | - "sh0": _decompress_hevc_kbit, |
74 | | - # "shN": _decompress_kmeans, |
75 | | - "shN": _decompress_masked_kmeans, |
76 | | - } |
77 | | - if param_name in decompress_fn_map: |
78 | | - return decompress_fn_map[param_name] |
| 87 | + # decompress_fn_map = { |
| 88 | + # "means": _decompress_png_16bit, |
| 89 | + # "scales": _decompress_hevc_kbit, |
| 90 | + # "quats": _decompress_quats_hevc_kbit, |
| 91 | + # "opacities": _decompress_hevc_kbit, |
| 92 | + # "sh0": _decompress_hevc_kbit, |
| 93 | + # # "shN": _decompress_kmeans, |
| 94 | + # "shN": _decompress_masked_kmeans, |
| 95 | + # } |
| 96 | + if param_name in self.decompress_fn_map: |
| 97 | + return self.decompress_fn_map[param_name] |
79 | 98 | else: |
80 | 99 | return _decompress_npz |
81 | 100 |
|
|
0 commit comments