66from contextlib import nullcontext
77from dataclasses import dataclass , field
88from collections import defaultdict
9- from typing import Dict , List , Optional , Tuple , Union , ContextManager
9+ from typing import Dict , List , Optional , Tuple , Union , ContextManager , TypedDict , Any
1010
1111import imageio
1212import nerfview
@@ -77,12 +77,65 @@ def _create_schedule(self):
7777 )
7878
7979 def update_schedule (self , ** kwargs ):
80- """动态更新schedule参数"""
8180 for key , value in kwargs .items ():
8281 if hasattr (self , key ):
8382 setattr (self , key , value )
8483 self .schedule = self ._create_schedule ()
8584
85+ @dataclass
86+ class CodecConfig :
87+ encode : str
88+ decode : str
89+
90+ @dataclass
91+ class AttributeCodecs :
92+ means : CodecConfig = field (default_factory = lambda : CodecConfig ("_compress_png_16bit" , "_decompress_png_16bit" ))
93+ scales : CodecConfig = field (default_factory = lambda : CodecConfig ("_compress_factorized_ans" , "_decompress_factorized_ans" ))
94+ quats : CodecConfig = field (default_factory = lambda : CodecConfig ("_compress_factorized_ans" , "_decompress_factorized_ans" ))
95+ opacities : CodecConfig = field (default_factory = lambda : CodecConfig ("_compress_png" , "_decompress_png" ))
96+ sh0 : CodecConfig = field (default_factory = lambda : CodecConfig ("_compress_png" , "_decompress_png" ))
97+ shN : CodecConfig = field (default_factory = lambda : CodecConfig ("_compress_masked_kmeans" , "_decompress_masked_kmeans" ))
98+
99+ def to_dict (self ) -> Dict [str , Dict [str , str ]]:
100+ return {
101+ attr : {"encode" : getattr (self , attr ).encode , "decode" : getattr (self , attr ).decode }
102+ for attr in ["means" , "scales" , "quats" , "opacities" , "sh0" , "shN" ]
103+ }
104+
105+ @dataclass
106+ class CompressionConfig :
107+ # Use PLAS sort in compression or not
108+ use_sort : bool = True
109+ # Verbose or not
110+ verbose : bool = True
111+ # QP value for video coding
112+ qp : Optional [int ] = field (default = None )
113+ # Number of cluster of VQ for shN compression
114+ n_clusters : int = 32768
115+ # Maps attribute names to their codec functions
116+ attribute_codec_registry : Optional [AttributeCodecs ] = field (default_factory = lambda : AttributeCodecs ())
117+
118+ def to_dict (self ) -> Dict [str , Any ]:
119+ """
120+ Convert the CompressionConfig instance to a dictionary.
121+ If attribute_codec_registry is not None, it will be converted to a dictionary using its to_dict method.
122+ Fields with None values (use_sort, verbose) will be excluded from the resulting dictionary.
123+ """
124+ result = {
125+ "use_sort" : self .use_sort ,
126+ "verbose" : self .verbose ,
127+ "n_clusters" : self .n_clusters ,
128+ }
129+
130+ if self .qp is not None :
131+ result ["qp" ] = self .qp
132+
133+ # handle attribute_codec_registry
134+ if self .attribute_codec_registry is not None :
135+ result ["attribute_codec_registry" ] = self .attribute_codec_registry .to_dict ()
136+
137+ return result
138+
86139@dataclass
87140class Config :
88141 # Disable viewer
@@ -91,8 +144,12 @@ class Config:
91144 ckpt : Optional [List [str ]] = None
92145 # Name of compression strategy to use
93146 compression : Optional [Literal ["png" , "entropy_coding" , "hevc" ]] = None
94- # Quantization parameters when set to hevc
95- qp : Optional [int ] = None
147+ # # Quantization parameters when set to hevc
148+ # qp: Optional[int] = None
149+ # Configuration for compression methods
150+ compression_cfg : CompressionConfig = field (
151+ default_factory = CompressionConfig
152+ )
96153
97154 # Enable profiler
98155 profiler_enabled : bool = False
@@ -549,9 +606,11 @@ def __init__(
549606 if cfg .compression == "png" :
550607 self .compression_method = PngCompression ()
551608 elif cfg .compression == "entropy_coding" :
552- self .compression_method = EntropyCodingCompression ()
609+ compression_cfg = cfg .compression_cfg .to_dict ()
610+ self .compression_method = EntropyCodingCompression (** compression_cfg )
553611 elif cfg .compression == "hevc" :
554- self .compression_method = HevcCompression (qp = cfg .qp )
612+ compression_cfg = cfg .compression_cfg .to_dict ()
613+ self .compression_method = HevcCompression (** compression_cfg )
555614 else :
556615 raise ValueError (f"Unknown compression strategy: { cfg .compression } " )
557616
0 commit comments