Skip to content

Commit b496339

Browse files
authored
Provide utility functions for calling from hydra (#128)
* Add mixin to allow pulling key-value pairs from python objects Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com> * Intermediate adding compose file Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com> * Remove old code, dont use mixin, and utility functions Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com> * Add update_from_base Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com> * Add type hints, change names, move import compose Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com> * Delayed import of model_alias in base.py to avoid circular import Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com> * Add deprecation warning for ccflow.base.model_alias Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com> * Remove model_alias from base.py, import from compose instead Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com> --------- Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com>
1 parent 907c06a commit b496339

File tree

9 files changed

+508
-16
lines changed

9 files changed

+508
-16
lines changed

ccflow/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
__version__ = "0.6.10"
22

3+
# Import exttypes early so modules that import `from ccflow import PyObjectPath` during
4+
# initialization find it (avoids circular import issues with functions that import utilities
5+
# which, in turn, import `ccflow`).
6+
from .exttypes import * # noqa: I001
7+
38
from .arrow import *
49
from .base import *
10+
from .compose import *
511
from .callable import *
612
from .context import *
713
from .enums import Enum
8-
from .exttypes import *
914
from .global_state import *
1015
from .models import *
1116
from .object_config import *

ccflow/base.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
log = logging.getLogger(__name__)
3434

3535
__all__ = (
36-
"model_alias",
3736
"BaseModel",
3837
"ModelRegistry",
3938
"ModelType",
@@ -323,20 +322,6 @@ def _is_config_subregistry(value):
323322
return False
324323

325324

326-
def model_alias(model_name: str) -> BaseModel:
327-
"""Function to alias a BaseModel by name in the root registry.
328-
329-
Useful for configs in hydra where we want a config object to point directly to another config object.
330-
331-
Args:
332-
model_name: The name of the underlying model to point to in the registry
333-
Example:
334-
_target_: ccflow.model_alias
335-
model_name: foo
336-
"""
337-
return BaseModel.model_validate(model_name)
338-
339-
340325
ModelType = TypeVar("ModelType", bound=BaseModel)
341326

342327

ccflow/compose.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from typing import Any, Dict, Optional, Type, Union
2+
3+
from .base import BaseModel
4+
from .exttypes.pyobjectpath import _TYPE_ADAPTER as PyObjectPathTA
5+
6+
__all__ = (
7+
"model_alias",
8+
"from_python",
9+
"update_from_template",
10+
)
11+
12+
13+
def model_alias(model_name: str) -> BaseModel:
14+
"""Return a model by alias from the registry.
15+
16+
Hydra-friendly: `_target_: ccflow.compose.model_alias` with `model_name`.
17+
18+
Args:
19+
model_name: Alias string registered in the model registry. Typically a
20+
short name that maps to a configured BaseModel.
21+
22+
Returns:
23+
A ``BaseModel`` instance resolved from the registry by ``model_name``.
24+
"""
25+
return BaseModel.model_validate(model_name)
26+
27+
28+
def from_python(py_object_path: str, indexer: Optional[list] = None) -> Any:
29+
"""Hydra-friendly: resolve and return any Python object by import path.
30+
31+
Optionally accepts ``indexer``, a list of keys that will be applied in
32+
order to index into the resolved object. No safety checks are performed;
33+
indexing errors will propagate.
34+
35+
Args:
36+
py_object_path: Dotted import path to a Python object, e.g.
37+
``mypkg.module.OBJECT`` or ``mypkg.module.ClassName``.
38+
indexer: Optional list of keys to apply in order to index into the
39+
resolved object (e.g., strings for dict keys or integers for list
40+
indexes).
41+
42+
Returns:
43+
The resolved Python object, or the value obtained after applying all
44+
``indexer`` keys to the resolved object.
45+
46+
Example YAML usage:
47+
some_value:
48+
_target_: ccflow.compose.from_python
49+
py_object_path: mypkg.module.OBJECT
50+
51+
nested_value:
52+
_target_: ccflow.compose.from_python
53+
py_object_path: mypkg.module.NESTED
54+
indexer: ["a", "b"]
55+
"""
56+
obj = PyObjectPathTA.validate_python(py_object_path).object
57+
if indexer:
58+
for key in indexer:
59+
obj = obj[key]
60+
return obj
61+
62+
63+
def update_from_template(
64+
base: Optional[Union[str, Dict[str, Any], BaseModel]] = None,
65+
*,
66+
target_class: Optional[Union[str, Type]] = None,
67+
update: Optional[Dict[str, Any]] = None,
68+
) -> Any:
69+
"""Generic update helper that constructs an instance from a base and updates.
70+
71+
Args:
72+
base: Either a registry alias string, a dict, or a Pydantic BaseModel. If BaseModel, it is converted
73+
to a shallow dict via ``dict(base)`` to preserve nested object identity.
74+
target_class: Optional path to the target class to construct. May be a
75+
string import path or the type itself. If None and ``base`` is a
76+
BaseModel, returns an instance of ``base.__class__``. If None and
77+
``base`` is a dict, returns the updated dict.
78+
update: Optional dict of updates to apply.
79+
80+
Returns:
81+
Instance of ``target_class`` if provided; otherwise an instance of the same
82+
class as ``base`` when base is a BaseModel; or the updated dict when base
83+
is a dict.
84+
"""
85+
# Determine base dict and default target
86+
default_target = None
87+
if isinstance(base, str):
88+
# Allow passing alias name directly; resolve from registry
89+
base = model_alias(base)
90+
if isinstance(base, BaseModel):
91+
base_dict = dict(base)
92+
default_target = base.__class__
93+
elif isinstance(base, dict):
94+
base_dict = dict(base)
95+
elif base is None:
96+
base_dict = {}
97+
else:
98+
raise TypeError("base must be a dict, BaseModel, or None")
99+
100+
# Merge updates: explicit dict first, then kwargs
101+
if update:
102+
base_dict.update(update)
103+
104+
# Resolve target class if provided as string path
105+
target = None
106+
if target_class is not None:
107+
if isinstance(target_class, str):
108+
target = PyObjectPathTA.validate_python(target_class).object
109+
else:
110+
target = target_class
111+
else:
112+
target = default_target
113+
114+
if target is None:
115+
# No target: return dict update for dict base
116+
return base_dict
117+
118+
# Construct instance of target with updated fields
119+
return target(**base_dict)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
shared_model:
2+
_target_: ccflow.compose.from_python
3+
py_object_path: ccflow.tests.data.python_object_samples.SHARED_MODEL
4+
5+
consumer:
6+
_target_: ccflow.tests.data.python_object_samples.Consumer
7+
shared: shared_model
8+
tag: consumer1
9+
10+
# Demonstrate from_python returning a dict (non-BaseModel)
11+
holder:
12+
_target_: ccflow.tests.data.python_object_samples.SharedHolder
13+
name: holder1
14+
cfg:
15+
_target_: ccflow.compose.from_python
16+
py_object_path: ccflow.tests.data.python_object_samples.SHARED_CFG
17+
18+
# Use update_from_template to update a field while preserving shared identity
19+
consumer_updated:
20+
_target_: ccflow.compose.update_from_template
21+
base: consumer
22+
update:
23+
tag: consumer2

ccflow/tests/data/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Shared test data modules.
2+
3+
Import sample configs using module-level objects in `python_object_samples`.
4+
"""
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Sample python objects for testing from_python and identity preservation."""
2+
3+
from typing import Dict
4+
5+
from ccflow import BaseModel
6+
7+
# Module-level objects
8+
SHARED_CFG: Dict[str, int] = {"x": 1, "y": 2}
9+
OTHER_CFG: Dict[str, int] = {"x": 10, "y": 20}
10+
"""Dict samples; identity for dicts is not guaranteed by Pydantic."""
11+
12+
NESTED_CFG = {
13+
"db": {"host": "seed.local", "port": 7000, "name": "seed"},
14+
"meta": {"env": "dev"},
15+
}
16+
17+
18+
class SharedHolder(BaseModel):
19+
name: str
20+
cfg: Dict[str, int]
21+
22+
23+
class SharedModel(BaseModel):
24+
val: int = 0
25+
26+
27+
# Module-level instance to be resolved via from_python
28+
SHARED_MODEL = SharedModel(val=42)
29+
30+
31+
class Consumer(BaseModel):
32+
shared: SharedModel
33+
tag: str = ""

0 commit comments

Comments
 (0)