Skip to content

Commit 5d4df98

Browse files
authored
Merge pull request #166 from Point72/vs/null_context
Make NullContext an alias of ContextBase
2 parents 250e95d + 50179c9 commit 5d4df98

File tree

4 files changed

+32
-15
lines changed

4 files changed

+32
-15
lines changed

ccflow/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,9 @@ class ContextBase(ResultBase):
829829

830830
@model_validator(mode="wrap")
831831
def _context_validator(cls, v, handler, info):
832+
if v is None:
833+
return handler({})
834+
832835
# Add deepcopy for v2 because it doesn't support copy_on_model_validation
833836
v = copy.deepcopy(v)
834837

ccflow/context.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,8 @@
9090

9191
_SEPARATOR = ","
9292

93-
94-
class NullContext(ContextBase):
95-
"""A Null Context that is used when no context is provided."""
96-
97-
@model_validator(mode="wrap")
98-
def _validate_none(cls, v, handler, info):
99-
v = v or {}
100-
return handler(v)
101-
93+
# Starting 0.8.0 Nullcontext is an alias to ContextBase
94+
NullContext = ContextBase
10295

10396
C = TypeVar("C", bound=Hashable)
10497

ccflow/tests/test_callable.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class MyExtendedContext(MyContext):
3232
c: bool
3333

3434

35+
class MyOtherContext(ContextBase):
36+
a: int
37+
38+
3539
class ListContext(ContextBase):
3640
ll: List[str] = []
3741

@@ -152,7 +156,7 @@ class BadModelMismatchedContextAndCall(CallableModel):
152156

153157
@property
154158
def context_type(self):
155-
return NullContext
159+
return MyOtherContext
156160

157161
@property
158162
def result_type(self):
@@ -163,7 +167,7 @@ def __call__(self, context: MyContext) -> MyResult:
163167
return context
164168

165169

166-
class BadModelGenericMismatchedContextAndCall(CallableModelGenericType[NullContext, MyResult]):
170+
class BadModelGenericMismatchedContextAndCall(CallableModelGenericType[MyOtherContext, MyResult]):
167171
"""Model with mismatched context_type and __call__ annotation"""
168172

169173
@Flow.call
@@ -460,7 +464,7 @@ def test_types(self):
460464
error = "__call__ method must take a single argument, named 'context'"
461465
self.assertRaisesRegex(ValueError, error, BadModelDoubleContextArg)
462466

463-
error = "The context_type <class 'ccflow.context.NullContext'> must match the type of the context accepted by __call__ <class 'ccflow.tests.test_callable.MyContext'>"
467+
error = "The context_type <class 'ccflow.tests.test_callable.MyOtherContext'> must match the type of the context accepted by __call__ <class 'ccflow.tests.test_callable.MyContext'>"
464468
self.assertRaisesRegex(ValueError, error, BadModelMismatchedContextAndCall)
465469

466470
error = "The result_type <class 'ccflow.result.generic.GenericResult'> must match the return type of __call__ <class 'ccflow.tests.test_callable.MyResult'>"
@@ -642,7 +646,7 @@ def __call__(self, context: NullContext) -> GenericResult[float]:
642646
MyCallable()
643647

644648
def test_types_generic(self):
645-
error = "Context type annotation <class 'ccflow.tests.test_callable.MyContext'> on __call__ does not match context_type <class 'ccflow.context.NullContext'> defined by CallableModelGenericType"
649+
error = "Context type annotation <class 'ccflow.tests.test_callable.MyContext'> on __call__ does not match context_type <class 'ccflow.tests.test_callable.MyOtherContext'> defined by CallableModelGenericType"
646650
self.assertRaisesRegex(TypeError, error, BadModelGenericMismatchedContextAndCall)
647651

648652
error = "Return type annotation <class 'ccflow.tests.test_callable.MyResult'> on __call__ does not match result_type <class 'ccflow.result.generic.GenericResult'> defined by CallableModelGenericType"

ccflow/tests/test_context.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@
3030
from ccflow.result import GenericResult
3131

3232

33+
class MyDefaultContext(ContextBase):
34+
b: float = 3.14
35+
c: bool = False
36+
37+
3338
class TestContexts(TestCase):
3439
def test_null_context(self):
3540
n1 = NullContext()
@@ -38,11 +43,25 @@ def test_null_context(self):
3843
self.assertEqual(hash(n1), hash(n2))
3944

4045
def test_null_context_validation(self):
46+
# Context creation is based on two main assumptions:
47+
# 1. If there is enough information to create a context, it should be created.
48+
# 2. Since NullContext has no required fields, it can be created from None,
49+
# empty containers ({} or []), or any other context.
4150
self.assertEqual(NullContext.model_validate([]), NullContext())
4251
self.assertEqual(NullContext.model_validate({}), NullContext())
4352
self.assertEqual(NullContext.model_validate(None), NullContext())
53+
self.assertIsInstance(NullContext.model_validate(DateContext(date="0d")), NullContext)
4454
self.assertRaises(ValueError, NullContext.model_validate, [True])
4555

56+
def test_context_with_defaults(self):
57+
# Contexts may define default values. Extending the assumptions above:
58+
# Any context inherits the behavior from NullContext, and can be
59+
# created as long as all required fields (if any) are satisfied.
60+
self.assertEqual(TypeAdapter(MyDefaultContext).validate_python(None), MyDefaultContext(b=3.14, c=False))
61+
self.assertEqual(TypeAdapter(MyDefaultContext).validate_python({}), MyDefaultContext(b=3.14, c=False))
62+
self.assertEqual(TypeAdapter(MyDefaultContext).validate_python([]), MyDefaultContext(b=3.14, c=False))
63+
self.assertEqual(TypeAdapter(MyDefaultContext).validate_python({"b": 10.0}), MyDefaultContext(b=10.0, c=False))
64+
4665
def test_date_validation(self):
4766
c = DateContext(date=date.today())
4867
self.assertEqual(DateContext(date=str(date.today())), c)
@@ -228,8 +247,6 @@ def setUp(self):
228247
for name, obj in inspect.getmembers(ctx, inspect.isclass)
229248
if obj.__module__ == ctx.__name__ and issubclass(obj, ContextBase) and not getattr(obj, "__deprecated__", False)
230249
}
231-
# TODO - remove NullContext until we fix the inheritance
232-
self.classes.pop("NullContext")
233250

234251
def test_field_ordering(self):
235252
"""Test that complex contexts have fields in the same order as the basic contexts they are composed of."""

0 commit comments

Comments
 (0)