Skip to content

Commit 677dfe9

Browse files
GeorgOstrovskiThe ml_collections Authors
authored andcommitted
Internal change only.
PiperOrigin-RevId: 797749449
1 parent 95c20f4 commit 677dfe9

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

ml_collections/config_dict/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
from .config_dict import required_placeholder
2828
from .config_dict import RequiredValueError
2929

30+
from .internal_utils import BestEffortCustomJSONEncoder
31+
from .internal_utils import Jsonable
32+
from .internal_utils import update_in_place
33+
from .internal_utils import copy_and_update
34+
3035
__all__ = ("_Op", "ConfigDict", "create", "CustomJSONEncoder", "FieldReference",
3136
"FrozenConfigDict", "JSONDecodeError", "MutabilityError",
3237
"placeholder", "recursive_rename", "required_placeholder",

ml_collections/config_dict/config_dict.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2108,3 +2108,84 @@ def recursive_rename(conf, old_name, new_name):
21082108
else:
21092109
setattr(new_conf, name, new_c)
21102110
return new_conf
2111+
2112+
2113+
# BEGIN GOOGLE-INTERNAL
2114+
# pylint: disable=g-import-not-at-top,g-bad-import-order
2115+
import copy
2116+
# pylint: enable=g-import-not-at-top,g-bad-import-order
2117+
2118+
2119+
def _setattr(
2120+
obj: Any,
2121+
key: str,
2122+
value: Any,
2123+
ensure_keys_exist: bool,
2124+
prefix: str,
2125+
) -> None:
2126+
"""Recursively setattr() for an object with a dot-separated key."""
2127+
if '.' in key:
2128+
key_first, key_rest = key.split('.', 1)
2129+
if not hasattr(obj, key_first):
2130+
if ensure_keys_exist:
2131+
full_key = prefix + key
2132+
child_key = prefix + key_first
2133+
raise AttributeError('Key "{}" cannot be set as "{}" was not found.'
2134+
.format(full_key, child_key))
2135+
else:
2136+
setattr(obj, key_first, {})
2137+
_setattr(getattr(obj, key_first), key_rest, value,
2138+
ensure_keys_exist=ensure_keys_exist,
2139+
prefix=prefix + key_first + '.')
2140+
else:
2141+
if ensure_keys_exist and not hasattr(obj, key):
2142+
full_key = prefix + key
2143+
raise AttributeError(
2144+
f'Key "{full_key}" cannot be set as "{full_key}" was not found.')
2145+
setattr(obj, key, value)
2146+
2147+
2148+
def update_in_place(
2149+
config: ConfigDict,
2150+
params: Mapping[str, Any],
2151+
ensure_keys_exist: bool = False,
2152+
) -> ConfigDict:
2153+
"""Returns the config dict updated in-place using dot-separated paths.
2154+
2155+
Args:
2156+
config: The config dict to update.
2157+
params: A dictionary containing the updated items. Keys are
2158+
dot-separated paths, and values are the updated values.
2159+
ensure_keys_exist: If True, raise an AttributeError, if a key does not
2160+
exist yet.
2161+
2162+
Returns:
2163+
The in-place updated config.
2164+
"""
2165+
for key, value in params.items():
2166+
_setattr(config, key, value,
2167+
ensure_keys_exist=ensure_keys_exist,
2168+
prefix='')
2169+
return config
2170+
2171+
2172+
def copy_and_update(
2173+
config: ConfigDict,
2174+
params: Mapping[str, Any],
2175+
ensure_keys_exist: bool = False,
2176+
) -> ConfigDict:
2177+
"""Copies a config dict and update using dot-separated paths.
2178+
2179+
Args:
2180+
config: The config dict to copy and update.
2181+
params: A dictionary containing the updated items. Keys are
2182+
dot-separated paths, and values are the updated values.
2183+
ensure_keys_exist: If True, raise an AttributeError, if a key does not exist
2184+
yet.
2185+
2186+
Returns:
2187+
A copy of the updated config dict.
2188+
"""
2189+
new_config = copy.deepcopy(config)
2190+
return update_in_place(new_config, params, ensure_keys_exist)
2191+
# END GOOGLE-INTERNAL

0 commit comments

Comments
 (0)