Skip to content

Commit 58b4650

Browse files
committed
Register locally defined CallableModels and Contexts in a module to work with PyObjectPath
Signed-off-by: Nijat Khanbabayev <nijat.khanbabayev@cubistsystematic.com>
1 parent 29534a3 commit 58b4650

File tree

9 files changed

+524
-5
lines changed

9 files changed

+524
-5
lines changed

ccflow/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from typing_extensions import Self
3131

3232
from .exttypes.pyobjectpath import PyObjectPath
33+
from .local_persistence import register_local_subclass
3334

3435
log = logging.getLogger(__name__)
3536

@@ -156,6 +157,15 @@ class BaseModel(PydanticBaseModel, _RegistryMixin, metaclass=_SerializeAsAnyMeta
156157
- Registration by name, and coercion from string name to allow for object re-use from the configs
157158
"""
158159

160+
__ccflow_local_registration_kind__: ClassVar[str] = "model"
161+
162+
@classmethod
163+
def __pydantic_init_subclass__(cls, **kwargs):
164+
# __pydantic_init_subclass__ is the supported hook point once Pydantic finishes wiring the subclass.
165+
super().__pydantic_init_subclass__(**kwargs)
166+
kind = getattr(cls, "__ccflow_local_registration_kind__", "model")
167+
register_local_subclass(cls, kind=kind)
168+
159169
@computed_field(
160170
alias="_target_",
161171
repr=False,
@@ -820,6 +830,8 @@ class ContextBase(ResultBase):
820830
that is an input into another CallableModel.
821831
"""
822832

833+
__ccflow_local_registration_kind__: ClassVar[str] = "context"
834+
823835
model_config = ConfigDict(
824836
frozen=True,
825837
arbitrary_types_allowed=False,

ccflow/callable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ class _CallableModel(BaseModel, abc.ABC):
7474
The purpose of this class is to provide type definitions of context_type and return_type.
7575
"""
7676

77+
__ccflow_local_registration_kind__: ClassVar[str] = "callable_model"
78+
7779
model_config = ConfigDict(
7880
ignored_types=(property,),
7981
)

ccflow/local_persistence.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""Helpers for persisting BaseModel-derived classes defined inside local scopes."""
2+
3+
from __future__ import annotations
4+
5+
import re
6+
import sys
7+
from collections import defaultdict
8+
from itertools import count
9+
from types import ModuleType
10+
from typing import Any, DefaultDict, Type
11+
12+
__all__ = ("LOCAL_ARTIFACTS_MODULE_NAME", "register_local_subclass")
13+
14+
15+
LOCAL_ARTIFACTS_MODULE_NAME = "ccflow._local_artifacts"
16+
_LOCAL_MODULE_DOC = "Auto-generated BaseModel subclasses created outside importable modules."
17+
18+
_SANITIZE_PATTERN = re.compile(r"[^0-9A-Za-z_]")
19+
_LOCAL_KIND_COUNTERS: DefaultDict[str, count] = defaultdict(lambda: count())
20+
21+
22+
def _ensure_module(name: str, doc: str) -> ModuleType:
23+
"""Ensure the dynamic module exists so import paths remain stable."""
24+
module = sys.modules.get(name)
25+
if module is None:
26+
module = ModuleType(name, doc)
27+
sys.modules[name] = module
28+
parent_name, _, attr = name.rpartition(".")
29+
if parent_name:
30+
parent_module = sys.modules.get(parent_name)
31+
if parent_module and not hasattr(parent_module, attr):
32+
setattr(parent_module, attr, module)
33+
return module
34+
35+
36+
_LOCAL_ARTIFACTS_MODULE = _ensure_module(LOCAL_ARTIFACTS_MODULE_NAME, _LOCAL_MODULE_DOC)
37+
38+
39+
def _needs_registration(cls: Type[Any]) -> bool:
40+
module = getattr(cls, "__module__", "")
41+
qualname = getattr(cls, "__qualname__", "")
42+
return "<locals>" in qualname or module == "__main__"
43+
44+
45+
def _sanitize_identifier(value: str, fallback: str) -> str:
46+
sanitized = _SANITIZE_PATTERN.sub("_", value or "")
47+
sanitized = sanitized.strip("_") or fallback
48+
if sanitized[0].isdigit():
49+
sanitized = f"_{sanitized}"
50+
return sanitized
51+
52+
53+
def _build_unique_name(*, kind_slug: str, name_hint: str) -> str:
54+
sanitized_hint = _sanitize_identifier(name_hint, "BaseModel")
55+
counter = _LOCAL_KIND_COUNTERS[kind_slug]
56+
return f"{kind_slug}__{sanitized_hint}__{next(counter)}"
57+
58+
59+
def register_local_subclass(cls: Type[Any], *, kind: str = "model") -> None:
60+
"""Register BaseModel subclasses created in local scopes."""
61+
if getattr(cls, "__module__", "").startswith(LOCAL_ARTIFACTS_MODULE_NAME):
62+
return
63+
if not _needs_registration(cls):
64+
return
65+
66+
name_hint = f"{getattr(cls, '__module__', '')}.{getattr(cls, '__qualname__', '')}"
67+
kind_slug = _sanitize_identifier(kind, "model").lower()
68+
unique_name = _build_unique_name(kind_slug=kind_slug, name_hint=name_hint)
69+
setattr(_LOCAL_ARTIFACTS_MODULE, unique_name, cls)
70+
cls.__module__ = _LOCAL_ARTIFACTS_MODULE.__name__
71+
cls.__qualname__ = unique_name
72+
setattr(cls, "__ccflow_dynamic_origin__", name_hint)

ccflow/tests/evaluators/test_common.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,24 @@
11
import logging
22
from datetime import date
3+
from typing import ClassVar
34
from unittest import TestCase
45

56
import pandas as pd
67
import pyarrow as pa
78

8-
from ccflow import DateContext, DateRangeContext, Evaluator, FlowOptionsOverride, ModelEvaluationContext, NullContext
9+
from ccflow import (
10+
CallableModel,
11+
ContextBase,
12+
DateContext,
13+
DateRangeContext,
14+
Evaluator,
15+
Flow,
16+
FlowOptions,
17+
FlowOptionsOverride,
18+
GenericResult,
19+
ModelEvaluationContext,
20+
NullContext,
21+
)
922
from ccflow.evaluators import (
1023
FallbackEvaluator,
1124
GraphEvaluator,
@@ -16,6 +29,7 @@
1629
combine_evaluators,
1730
get_dependency_graph,
1831
)
32+
from ccflow.tests.local_helpers import build_nested_graph_chain
1933

2034
from .util import CircularModel, MyDateCallable, MyDateRangeCallable, MyRaisingCallable, NodeModel, ResultModel
2135

@@ -473,3 +487,37 @@ def test_graph_evaluator_circular(self):
473487
evaluator = GraphEvaluator()
474488
with FlowOptionsOverride(options={"evaluator": evaluator}):
475489
self.assertRaises(Exception, root, context)
490+
491+
def test_graph_evaluator_with_local_models_and_cache(self):
492+
ParentCls, ChildCls = build_nested_graph_chain()
493+
ChildCls.call_count = 0
494+
model = ParentCls(child=ChildCls())
495+
evaluator = MultiEvaluator(evaluators=[GraphEvaluator(), MemoryCacheEvaluator()])
496+
with FlowOptionsOverride(options=FlowOptions(evaluator=evaluator, cacheable=True)):
497+
first = model(NullContext())
498+
second = model(NullContext())
499+
self.assertEqual(first.value, second.value)
500+
self.assertEqual(ChildCls.call_count, 1)
501+
502+
503+
class TestMemoryCacheEvaluatorLocal(TestCase):
504+
def test_memory_cache_handles_local_context_and_callable(self):
505+
class LocalContext(ContextBase):
506+
value: int
507+
508+
class LocalModel(CallableModel):
509+
call_count: ClassVar[int] = 0
510+
511+
@Flow.call
512+
def __call__(self, context: LocalContext) -> GenericResult:
513+
type(self).call_count += 1
514+
return GenericResult(value=context.value * 2)
515+
516+
evaluator = MemoryCacheEvaluator()
517+
LocalModel.call_count = 0
518+
model = LocalModel()
519+
with FlowOptionsOverride(options=FlowOptions(evaluator=evaluator, cacheable=True)):
520+
result1 = model(LocalContext(value=5))
521+
result2 = model(LocalContext(value=5))
522+
self.assertEqual(result1.value, result2.value)
523+
self.assertEqual(LocalModel.call_count, 1)

ccflow/tests/local_helpers.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Shared helpers for constructing local-scope contexts/models in tests."""
2+
3+
from typing import ClassVar, Tuple, Type
4+
5+
from ccflow import CallableModel, ContextBase, Flow, GenericResult, GraphDepList, NullContext
6+
7+
8+
def build_local_callable(name: str = "LocalCallable") -> Type[CallableModel]:
9+
class _LocalCallable(CallableModel):
10+
@Flow.call
11+
def __call__(self, context: NullContext) -> GenericResult:
12+
return GenericResult(value="local")
13+
14+
_LocalCallable.__name__ = name
15+
return _LocalCallable
16+
17+
18+
def build_local_context(name: str = "LocalContext") -> Type[ContextBase]:
19+
class _LocalContext(ContextBase):
20+
value: int
21+
22+
_LocalContext.__name__ = name
23+
return _LocalContext
24+
25+
26+
def build_nested_graph_chain() -> Tuple[Type[CallableModel], Type[CallableModel]]:
27+
class LocalLeaf(CallableModel):
28+
call_count: ClassVar[int] = 0
29+
30+
@Flow.call
31+
def __call__(self, context: NullContext) -> GenericResult:
32+
type(self).call_count += 1
33+
return GenericResult(value="leaf")
34+
35+
class LocalParent(CallableModel):
36+
child: LocalLeaf
37+
38+
@Flow.call
39+
def __call__(self, context: NullContext) -> GenericResult:
40+
return self.child(context)
41+
42+
@Flow.deps
43+
def __deps__(self, context: NullContext) -> GraphDepList:
44+
return [(self.child, [context])]
45+
46+
return LocalParent, LocalLeaf

ccflow/tests/test_base.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from typing import Any, Dict, List
2-
from unittest import TestCase
2+
from unittest import TestCase, mock
33

4-
from pydantic import ConfigDict, ValidationError
4+
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, ValidationError
55

6-
from ccflow import BaseModel, PyObjectPath
6+
from ccflow import BaseModel, CallableModel, ContextBase, Flow, GenericResult, NullContext, PyObjectPath
7+
from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME
78

89

910
class ModelA(BaseModel):
@@ -106,6 +107,20 @@ def test_type_after_assignment(self):
106107
self.assertIsInstance(m.type_, PyObjectPath)
107108
self.assertEqual(m.type_, path)
108109

110+
def test_pyobjectpath_requires_ccflow_local_registration(self):
111+
class PlainLocalModel(PydanticBaseModel):
112+
value: int
113+
114+
with self.assertRaises(ValueError):
115+
PyObjectPath.validate(PlainLocalModel)
116+
117+
class LocalCcflowModel(BaseModel):
118+
value: int
119+
120+
path = PyObjectPath.validate(LocalCcflowModel)
121+
self.assertEqual(path.object, LocalCcflowModel)
122+
self.assertTrue(str(path).startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}."))
123+
109124
def test_validate(self):
110125
self.assertEqual(ModelA.model_validate({"x": "foo"}), ModelA(x="foo"))
111126
type_ = "ccflow.tests.test_base.ModelA"
@@ -157,3 +172,52 @@ def test_widget(self):
157172
{"expanded": True, "root": "bar"},
158173
),
159174
)
175+
176+
177+
class TestLocalRegistrationKind(TestCase):
178+
def test_base_model_defaults_to_model_kind(self):
179+
with mock.patch("ccflow.base.register_local_subclass") as register:
180+
181+
class LocalModel(BaseModel):
182+
value: int
183+
184+
register.assert_called_once()
185+
args, kwargs = register.call_args
186+
self.assertIs(args[0], LocalModel)
187+
self.assertEqual(kwargs["kind"], "model")
188+
189+
def test_context_defaults_to_context_kind(self):
190+
with mock.patch("ccflow.base.register_local_subclass") as register:
191+
192+
class LocalContext(ContextBase):
193+
value: int
194+
195+
register.assert_called_once()
196+
args, kwargs = register.call_args
197+
self.assertIs(args[0], LocalContext)
198+
self.assertEqual(kwargs["kind"], "context")
199+
200+
def test_callable_defaults_to_callable_kind(self):
201+
with mock.patch("ccflow.base.register_local_subclass") as register:
202+
203+
class LocalCallable(CallableModel):
204+
@Flow.call
205+
def __call__(self, context: NullContext) -> GenericResult:
206+
return GenericResult(value="ok")
207+
208+
register.assert_called_once()
209+
args, kwargs = register.call_args
210+
self.assertIs(args[0], LocalCallable)
211+
self.assertEqual(kwargs["kind"], "callable_model")
212+
213+
def test_explicit_override_respected(self):
214+
with mock.patch("ccflow.base.register_local_subclass") as register:
215+
216+
class CustomKind(BaseModel):
217+
__ccflow_local_registration_kind__ = "custom"
218+
value: int
219+
220+
register.assert_called_once()
221+
args, kwargs = register.call_args
222+
self.assertIs(args[0], CustomKind)
223+
self.assertEqual(kwargs["kind"], "custom")

0 commit comments

Comments
 (0)