Skip to content

Commit 93702be

Browse files
committed
refactor & doc
1 parent 91a3300 commit 93702be

File tree

71 files changed

+9279
-1206
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+9279
-1206
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,7 @@ examples/helper/src/
156156

157157
logs/
158158

159-
AGENTS.md
159+
AGENTS.md
160+
*.pdf
161+
*.tex
162+
.cursor

cfg_sys/__init__.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Configuration system utilities for GSCodec Studio.
2+
3+
This package centralizes helpers that manage experiment configurations,
4+
allowing multiple entry points to share a consistent set of behaviors for
5+
parsing overrides, instantiating modules, and recording snapshots.
6+
"""
7+
8+
from .config_manager import (
9+
CONFIG_FILE_FLAGS,
10+
CONFIG_SAVE_FLAGS,
11+
DEFAULT_CONFIG_SNAPSHOT,
12+
apply_updates,
13+
build_from_registry,
14+
coerce_value,
15+
load_config_updates,
16+
pop_flag_value,
17+
prepare_presets,
18+
save_config_snapshot,
19+
serialize_config_value,
20+
synchronize_compression_config,
21+
)
22+
23+
__all__ = [
24+
"CONFIG_FILE_FLAGS",
25+
"CONFIG_SAVE_FLAGS",
26+
"DEFAULT_CONFIG_SNAPSHOT",
27+
"apply_updates",
28+
"build_from_registry",
29+
"coerce_value",
30+
"load_config_updates",
31+
"pop_flag_value",
32+
"prepare_presets",
33+
"save_config_snapshot",
34+
"serialize_config_value",
35+
"synchronize_compression_config",
36+
]

cfg_sys/config_manager.py

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
"""Reusable configuration utilities for experiment entry points."""
2+
3+
from __future__ import annotations
4+
5+
import copy
6+
import sys
7+
from collections.abc import Mapping, MutableMapping, Sequence
8+
from dataclasses import fields, is_dataclass
9+
from inspect import isclass
10+
from pathlib import Path
11+
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
12+
13+
import yaml
14+
15+
# Flag naming conventions shared across entry points.
16+
CONFIG_FILE_FLAGS: Tuple[str, ...] = ("--config", "--config-path", "-c")
17+
CONFIG_SAVE_FLAGS: Tuple[str, ...] = ("--save-config", "--config-save")
18+
DEFAULT_CONFIG_SNAPSHOT = "config_snapshot.yaml"
19+
20+
# Type aliases for clarity.
21+
Registry = Mapping[str, Callable[..., Any]]
22+
ApplyFn = Callable[[Any, Dict[str, Any], str], None]
23+
RegistryHandler = Callable[[Any, Optional[Any], str, ApplyFn], Any]
24+
Serializer = Callable[[Any, Callable[[Any], Any]], Optional[Any]]
25+
26+
27+
def pop_flag_value(flag_names: Sequence[str], argv: Optional[Sequence[str]] = None) -> Optional[str]:
28+
"""Remove and return the value for the first matching flag in ``argv``."""
29+
30+
argv_list = sys.argv if argv is None else list(argv)
31+
i = 1
32+
while i < len(argv_list):
33+
arg = argv_list[i]
34+
for flag in flag_names:
35+
if arg == flag:
36+
if i + 1 >= len(argv_list):
37+
raise ValueError(f"Flag {flag} requires a value.")
38+
value = argv_list[i + 1]
39+
del argv_list[i : i + 2]
40+
if argv is None:
41+
sys.argv[:] = argv_list
42+
return value
43+
if arg.startswith(f"{flag}="):
44+
value = arg.split("=", 1)[1]
45+
del argv_list[i]
46+
if argv is None:
47+
sys.argv[:] = argv_list
48+
return value
49+
i += 1
50+
if argv is None:
51+
sys.argv[:] = argv_list
52+
return None
53+
54+
55+
def load_config_updates(config_path: Path) -> Dict[str, Any]:
56+
"""Load overrides from a YAML file."""
57+
58+
with config_path.open("r") as f:
59+
data = yaml.safe_load(f) or {}
60+
if not isinstance(data, MutableMapping):
61+
raise ValueError(f"Config file {config_path} must contain a mapping at the top level.")
62+
return dict(data)
63+
64+
65+
def coerce_value(reference: Any, value: Any) -> Any:
66+
"""Best-effort coercion to keep tuple/list shapes consistent."""
67+
68+
if isinstance(reference, tuple) and isinstance(value, list):
69+
return type(reference)(value)
70+
if isinstance(reference, tuple) and isinstance(value, tuple):
71+
return type(reference)(value)
72+
if isinstance(reference, list) and isinstance(value, tuple):
73+
return list(value)
74+
if isinstance(reference, list) and not isinstance(value, (list, tuple)):
75+
return [value]
76+
return value
77+
78+
79+
def apply_updates(
80+
target: Any,
81+
updates: Dict[str, Any],
82+
path: str = "cfg",
83+
registry_handlers: Optional[Dict[str, RegistryHandler]] = None,
84+
) -> None:
85+
"""Recursively apply ``updates`` onto ``target``."""
86+
87+
if not isinstance(updates, MutableMapping):
88+
raise ValueError(f"Expected mapping for updates at {path}, got {type(updates).__name__}")
89+
90+
registry_handlers = registry_handlers or {}
91+
92+
def _apply_nested(nested_target: Any, nested_updates: Dict[str, Any], nested_path: str) -> None:
93+
apply_updates(nested_target, nested_updates, nested_path, registry_handlers)
94+
95+
if is_dataclass(target):
96+
field_lookup = {f.name: f for f in fields(target)}
97+
for key, value in updates.items():
98+
if key == "type":
99+
continue
100+
if key not in field_lookup:
101+
raise KeyError(f"Unknown configuration key '{path}.{key}'")
102+
103+
current_value = getattr(target, key)
104+
next_path = f"{path}.{key}"
105+
106+
if key in registry_handlers:
107+
handler = registry_handlers[key]
108+
new_value = handler(value, current_value, next_path, _apply_nested)
109+
setattr(target, key, new_value)
110+
continue
111+
112+
if is_dataclass(current_value) and isinstance(value, MutableMapping):
113+
_apply_nested(current_value, dict(value), next_path)
114+
continue
115+
116+
if isinstance(current_value, MutableMapping) and isinstance(value, MutableMapping):
117+
merged = copy.deepcopy(current_value)
118+
for sub_key, sub_value in value.items():
119+
if (
120+
sub_key in merged
121+
and is_dataclass(merged[sub_key])
122+
and isinstance(sub_value, MutableMapping)
123+
):
124+
apply_updates(
125+
merged[sub_key],
126+
dict(sub_value),
127+
f"{next_path}.{sub_key}",
128+
registry_handlers,
129+
)
130+
else:
131+
merged[sub_key] = sub_value
132+
setattr(target, key, merged)
133+
continue
134+
135+
setattr(target, key, coerce_value(current_value, value))
136+
return
137+
138+
if isinstance(target, MutableMapping):
139+
for key, value in updates.items():
140+
target[key] = value
141+
return
142+
143+
raise ValueError(f"Unsupported target type '{type(target).__name__}' at {path}")
144+
145+
146+
def build_from_registry(
147+
value: Any,
148+
registry: Registry,
149+
current: Optional[Any] = None,
150+
*,
151+
apply_fn: Optional[ApplyFn] = None,
152+
path: str = "cfg",
153+
) -> Any:
154+
"""Instantiate or update an object using a registry mapping."""
155+
156+
if any(isinstance(value, cls) for cls in registry.values() if isclass(cls)):
157+
return value
158+
159+
if isinstance(value, str):
160+
if value not in registry:
161+
raise ValueError(f"Unsupported registry type '{value}' at {path}")
162+
factory = registry[value]
163+
return factory()
164+
165+
if isinstance(value, MutableMapping):
166+
params = dict(value)
167+
type_name = params.pop("type", None)
168+
params_dict = params.pop("params", None)
169+
if params and params_dict is not None:
170+
params_dict.update(params)
171+
args = params_dict if params_dict is not None else params
172+
173+
if type_name is None:
174+
if current is None:
175+
raise ValueError(f"Registry entry type must be specified at {path}")
176+
if args and apply_fn is not None:
177+
apply_fn(current, args, path)
178+
return current
179+
180+
if type_name not in registry:
181+
raise ValueError(f"Unsupported registry type '{type_name}' at {path}")
182+
factory = registry[type_name]
183+
184+
expected_type = factory if isclass(factory) else None
185+
if current is not None and expected_type is not None and isinstance(current, expected_type):
186+
if args and apply_fn is not None:
187+
apply_fn(current, args, f"{path}.{type_name}")
188+
return current
189+
190+
if args:
191+
return factory(**args)
192+
return factory()
193+
194+
raise ValueError(f"Unsupported registry specification at {path}: {value!r}")
195+
196+
197+
def prepare_presets(
198+
presets: Mapping[str, Tuple[str, Any]],
199+
updates: Optional[Dict[str, Any]],
200+
*,
201+
registry_handlers: Optional[Dict[str, RegistryHandler]] = None,
202+
) -> Dict[str, Tuple[str, Any]]:
203+
"""Return deep-copied presets with overrides applied."""
204+
205+
prepared: Dict[str, Tuple[str, Any]] = {}
206+
for name, (description, cfg) in presets.items():
207+
cfg_copy = copy.deepcopy(cfg)
208+
if updates:
209+
apply_updates(cfg_copy, dict(updates), registry_handlers=registry_handlers)
210+
prepared[name] = (description, cfg_copy)
211+
return prepared
212+
213+
214+
def serialize_config_value(
215+
value: Any,
216+
*,
217+
custom_serializers: Optional[Iterable[Serializer]] = None,
218+
) -> Any:
219+
"""Recursively convert configuration values into YAML-friendly objects."""
220+
221+
# 添加内置序列化器
222+
def path_serializer(obj: Any, _serialize: Callable) -> Any:
223+
from pathlib import Path
224+
if isinstance(obj, Path):
225+
return str(obj)
226+
return None
227+
228+
def torch_dtype_serializer(obj: Any, _serialize: Callable) -> Any:
229+
try:
230+
import torch
231+
if isinstance(obj, torch.dtype):
232+
return str(obj)
233+
except ImportError:
234+
pass
235+
return None
236+
237+
# 将内置序列化器添加到自定义序列化器列表前面
238+
serializers = [
239+
path_serializer,
240+
torch_dtype_serializer,
241+
*(custom_serializers or [])
242+
]
243+
244+
def _serialize(obj: Any) -> Any:
245+
for serializer in serializers:
246+
result = serializer(obj, _serialize)
247+
if result is not None:
248+
return result
249+
250+
if is_dataclass(obj):
251+
return {f.name: _serialize(getattr(obj, f.name)) for f in fields(obj)}
252+
if isinstance(obj, Mapping):
253+
return {k: _serialize(v) for k, v in obj.items()}
254+
if isinstance(obj, (list, tuple)):
255+
return [_serialize(v) for v in obj]
256+
return obj
257+
258+
return _serialize(value)
259+
260+
261+
def save_config_snapshot(
262+
cfg: Any,
263+
destination: Path,
264+
*,
265+
custom_serializers: Optional[Iterable[Serializer]] = None,
266+
) -> None:
267+
"""Persist a configuration dataclass to ``destination`` as YAML."""
268+
269+
destination.parent.mkdir(parents=True, exist_ok=True)
270+
serializable = {
271+
f.name: serialize_config_value(getattr(cfg, f.name), custom_serializers=custom_serializers)
272+
for f in fields(cfg)
273+
}
274+
with destination.open("w") as f:
275+
yaml.safe_dump(serializable, f, sort_keys=False)
276+
277+
278+
def synchronize_compression_config(cfg: Any) -> None:
279+
"""Keep compression-related configuration flags in sync."""
280+
281+
comp_cfg = cfg.compression_sim_cfg
282+
283+
if getattr(cfg, "compression_sim", False):
284+
comp_cfg.enabled = True
285+
else:
286+
cfg.compression_sim = comp_cfg.enabled
287+
288+
if getattr(cfg, "entropy_model_opt", False):
289+
comp_cfg.entropy.enabled = True
290+
else:
291+
cfg.entropy_model_opt = comp_cfg.entropy.enabled
292+
293+
comp_cfg.entropy.model_type = getattr(cfg, "entropy_model_type", None) or comp_cfg.entropy.model_type
294+
cfg.entropy_model_type = comp_cfg.entropy.model_type
295+
296+
entropy_steps = getattr(cfg, "entropy_steps", None)
297+
if entropy_steps:
298+
comp_cfg.entropy.steps.update(entropy_steps)
299+
cfg.entropy_steps = comp_cfg.entropy.steps
300+
301+
if getattr(cfg, "shN_ada_mask_opt", False):
302+
comp_cfg.mask.enabled = True
303+
else:
304+
cfg.shN_ada_mask_opt = comp_cfg.mask.enabled
305+
306+
strategy = getattr(cfg, "shN_ada_mask_strategy", None)
307+
if strategy is not None:
308+
comp_cfg.mask.strategy = strategy
309+
elif comp_cfg.mask.strategy is not None:
310+
cfg.shN_ada_mask_strategy = comp_cfg.mask.strategy
311+
312+
mask_steps = getattr(cfg, "ada_mask_steps", None)
313+
if mask_steps is not None:
314+
comp_cfg.mask.start_step = mask_steps
315+
elif comp_cfg.mask.start_step is not None:
316+
cfg.ada_mask_steps = comp_cfg.mask.start_step
317+
318+
cfg.compression_sim_cfg = comp_cfg

0 commit comments

Comments
 (0)