Skip to content

Commit c215858

Browse files
committed
bugfix: apply new config to ply_loader_renderer.py
1 parent 2362099 commit c215858

File tree

2 files changed

+97
-32
lines changed

2 files changed

+97
-32
lines changed

examples/ply_loader_renderer.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from contextlib import nullcontext
77
from dataclasses import dataclass, field
88
from collections import defaultdict
9-
from typing import Dict, List, Optional, Tuple, Union, ContextManager, TypedDict
9+
from typing import Dict, List, Optional, Tuple, Union, ContextManager, TypedDict, Any
1010

1111
import imageio
1212
import nerfview
@@ -83,15 +83,59 @@ def update_schedule(self, **kwargs):
8383
setattr(self, key, value)
8484
self.schedule = self._create_schedule()
8585

86-
class CompressionConfig(TypedDict, total=False):
86+
@dataclass
87+
class CodecConfig:
88+
encode: str
89+
decode: str
90+
91+
@dataclass
92+
class AttributeCodecs:
93+
means: CodecConfig = field(default_factory=lambda: CodecConfig("_compress_png_16bit", "_decompress_png_16bit"))
94+
scales: CodecConfig = field(default_factory=lambda: CodecConfig("_compress_factorized_ans", "_decompress_factorized_ans"))
95+
quats: CodecConfig = field(default_factory=lambda: CodecConfig("_compress_factorized_ans", "_decompress_factorized_ans"))
96+
opacities: CodecConfig = field(default_factory=lambda: CodecConfig("_compress_png", "_decompress_png"))
97+
sh0: CodecConfig = field(default_factory=lambda: CodecConfig("_compress_png", "_decompress_png"))
98+
shN: CodecConfig = field(default_factory=lambda: CodecConfig("_compress_masked_kmeans", "_decompress_masked_kmeans"))
99+
100+
def to_dict(self) -> Dict[str, Dict[str, str]]:
101+
return {
102+
attr: {"encode": getattr(self, attr).encode, "decode": getattr(self, attr).decode}
103+
for attr in ["means", "scales", "quats", "opacities", "sh0", "shN"]
104+
}
105+
106+
@dataclass
107+
class CompressionConfig:
87108
# Use PLAS sort in compression or not
88-
use_sort: Optional[bool] = field(default=None)
109+
use_sort: bool = True
89110
# Verbose or not
90-
verbose: Optional[bool] = field(default=None)
111+
verbose: bool = True
91112
# QP value for video coding
92-
qp: int = 4
113+
qp: Optional[int] = field(default=None)
93114
# Number of cluster of VQ for shN compression
94115
n_clusters: int = 32768
116+
# Maps attribute names to their codec functions
117+
attribute_codec_registry: Optional[AttributeCodecs] = field(default_factory=lambda: AttributeCodecs())
118+
119+
def to_dict(self) -> Dict[str, Any]:
120+
"""
121+
Convert the CompressionConfig instance to a dictionary.
122+
If attribute_codec_registry is not None, it will be converted to a dictionary using its to_dict method.
123+
Fields with None values (use_sort, verbose) will be excluded from the resulting dictionary.
124+
"""
125+
result = {
126+
"use_sort": self.use_sort,
127+
"verbose": self.verbose,
128+
"n_clusters": self.n_clusters,
129+
}
130+
131+
if self.qp is not None:
132+
result["qp"] = self.qp
133+
134+
# handle attribute_codec_registry
135+
if self.attribute_codec_registry is not None:
136+
result["attribute_codec_registry"] = self.attribute_codec_registry.to_dict()
137+
138+
return result
95139

96140
@dataclass
97141
class Config:
@@ -565,9 +609,11 @@ def __init__(
565609
if cfg.compression == "png":
566610
self.compression_method = PngCompression()
567611
elif cfg.compression == "entropy_coding":
568-
self.compression_method = EntropyCodingCompression()
612+
compression_cfg = cfg.compression_cfg.to_dict()
613+
self.compression_method = EntropyCodingCompression(**compression_cfg)
569614
elif cfg.compression == "hevc":
570-
self.compression_method = HevcCompression(**cfg.compression_cfg) # compression_cfg=cfg.compression_cfg
615+
compression_cfg = cfg.compression_cfg.to_dict()
616+
self.compression_method = HevcCompression(**compression_cfg)
571617
else:
572618
raise ValueError(f"Unknown compression strategy: {cfg.compression}")
573619

gsplat/compression/hevc_compression.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import json
22
import os
3-
from dataclasses import dataclass
3+
from dataclasses import dataclass, field, InitVar
44
import shutil
5-
from typing import Any, Callable, Dict
5+
from typing import Any, Callable, Dict, Optional
66

77
import numpy as np
88
import torch
@@ -46,36 +46,55 @@ class HevcCompression:
4646

4747
use_sort: bool = True
4848
verbose: bool = True
49-
qp: int = 4
5049
n_clusters: int = 32768
50+
qp: int = 4
51+
52+
attribute_codec_registry: InitVar[Optional[Dict[str, str]]] = None
53+
54+
compress_fn_map: Dict[str, Callable] = field(default_factory=lambda: {
55+
"means": _compress_png_16bit,
56+
"scales": _compress_hevc_kbit,
57+
"quats": _compress_quats_hevc_kbit,
58+
"opacities": _compress_hevc_kbit,
59+
"sh0": _compress_hevc_kbit,
60+
"shN": _compress_masked_kmeans,
61+
})
62+
decompress_fn_map: Dict[str, Callable] = field(default_factory=lambda: {
63+
"means": _decompress_png_16bit,
64+
"scales": _decompress_hevc_kbit,
65+
"quats": _decompress_quats_hevc_kbit,
66+
"opacities": _decompress_hevc_kbit,
67+
"sh0": _decompress_hevc_kbit,
68+
"shN": _decompress_masked_kmeans,
69+
})
5170

5271
def _get_compress_fn(self, param_name: str) -> Callable:
53-
compress_fn_map = {
54-
"means": _compress_png_16bit,
55-
"scales": _compress_hevc_kbit,
56-
"quats": _compress_quats_hevc_kbit,
57-
"opacities": _compress_hevc_kbit,
58-
"sh0": _compress_hevc_kbit,
59-
# "shN": _compress_kmeans,
60-
"shN": _compress_masked_kmeans,
61-
}
62-
if param_name in compress_fn_map:
63-
return compress_fn_map[param_name]
72+
# compress_fn_map = {
73+
# "means": _compress_png_16bit,
74+
# "scales": _compress_hevc_kbit,
75+
# "quats": _compress_quats_hevc_kbit,
76+
# "opacities": _compress_hevc_kbit,
77+
# "sh0": _compress_hevc_kbit,
78+
# # "shN": _compress_kmeans,
79+
# "shN": _compress_masked_kmeans,
80+
# }
81+
if param_name in self.compress_fn_map:
82+
return self.compress_fn_map[param_name]
6483
else:
6584
return _compress_npz
6685

6786
def _get_decompress_fn(self, param_name: str) -> Callable:
68-
decompress_fn_map = {
69-
"means": _decompress_png_16bit,
70-
"scales": _decompress_hevc_kbit,
71-
"quats": _decompress_quats_hevc_kbit,
72-
"opacities": _decompress_hevc_kbit,
73-
"sh0": _decompress_hevc_kbit,
74-
# "shN": _decompress_kmeans,
75-
"shN": _decompress_masked_kmeans,
76-
}
77-
if param_name in decompress_fn_map:
78-
return decompress_fn_map[param_name]
87+
# decompress_fn_map = {
88+
# "means": _decompress_png_16bit,
89+
# "scales": _decompress_hevc_kbit,
90+
# "quats": _decompress_quats_hevc_kbit,
91+
# "opacities": _decompress_hevc_kbit,
92+
# "sh0": _decompress_hevc_kbit,
93+
# # "shN": _decompress_kmeans,
94+
# "shN": _decompress_masked_kmeans,
95+
# }
96+
if param_name in self.decompress_fn_map:
97+
return self.decompress_fn_map[param_name]
7998
else:
8099
return _decompress_npz
81100

0 commit comments

Comments
 (0)