Skip to content

Commit ca4b202

Browse files
authored
Allow signature and property to match in either direction (#151)
Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com>
1 parent b496339 commit ca4b202

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

ccflow/callable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def _check_signature(self):
110110
not isinstance(type_call_arg, TypeVar)
111111
and type_call_arg is not Signature.empty
112112
and (not isclass(type_call_arg) or not issubclass(type_call_arg, self.context_type))
113+
and (not isclass(self.context_type) or not issubclass(self.context_type, type_call_arg))
113114
):
114115
err_msg_type_mismatch = (
115116
f"The context_type {self.context_type} must match the type of the context accepted by __call__ {type_call_arg}"
@@ -122,6 +123,7 @@ def _check_signature(self):
122123
not isinstance(type_call_return, TypeVar)
123124
and type_call_return is not Signature.empty
124125
and (not isclass(type_call_return) or not issubclass(type_call_return, self.result_type))
126+
and (not isclass(self.result_type) or not issubclass(self.result_type, type_call_return))
125127
):
126128
err_msg_type_mismatch = f"The result_type {self.result_type} must match the return type of __call__ {type_call_return}"
127129
raise ValueError(err_msg_type_mismatch)

ccflow/tests/test_callable.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pickle import dumps as pdumps, loads as ploads
2-
from typing import Generic, List, Optional, TypeVar
2+
from typing import Generic, List, Optional, Tuple, Type, TypeVar
33
from unittest import TestCase
44

55
import ray
@@ -268,6 +268,30 @@ def __call__(self, context: NullContext) -> GenericResult[int]:
268268
return GenericResult[int](value=42)
269269

270270

271+
class MyGenericContext(ContextBase, Generic[TContext]):
272+
value: TContext
273+
274+
275+
class ModelMixedGenericsEnforceContextMatch(CallableModel, Generic[TContext, TResult]):
276+
model: CallableModelGenericType[TContext, TResult]
277+
278+
@property
279+
def context_type(self) -> Type[ContextType]:
280+
return MyGenericContext[self.model.context_type]
281+
282+
@property
283+
def result_type(self) -> Type[ResultType]:
284+
return GenericResult[self.model.result_type]
285+
286+
@Flow.deps
287+
def __deps__(self, context: MyGenericContext[TContext]) -> List[Tuple[CallableModelGenericType[TContext, TResult], List[ContextType]]]:
288+
return []
289+
290+
@Flow.call
291+
def __call__(self, context: MyGenericContext[TContext]) -> TResult:
292+
return GenericResult(value=None)
293+
294+
271295
class TestContext(TestCase):
272296
def test_immutable(self):
273297
x = MyContext(a="foo")
@@ -413,6 +437,10 @@ def test_identity(self):
413437
self.assertEqual(ident(context), context)
414438
self.assertIsNot(ident(context), context)
415439

440+
def test_context_call_match_enforcement_generic_base(self):
441+
# This should not raise
442+
_ = ModelMixedGenericsEnforceContextMatch(model=IdentityCallable())
443+
416444

417445
class TestWrapperModel(TestCase):
418446
def test_wrapper(self):

0 commit comments

Comments
 (0)