|
1 | 1 | from pickle import dumps as pdumps, loads as ploads |
2 | | -from typing import Generic, List, Optional, TypeVar |
| 2 | +from typing import Generic, List, Optional, Tuple, Type, TypeVar |
3 | 3 | from unittest import TestCase |
4 | 4 |
|
5 | 5 | import ray |
@@ -268,6 +268,30 @@ def __call__(self, context: NullContext) -> GenericResult[int]: |
268 | 268 | return GenericResult[int](value=42) |
269 | 269 |
|
270 | 270 |
|
| 271 | +class MyGenericContext(ContextBase, Generic[TContext]): |
| 272 | + value: TContext |
| 273 | + |
| 274 | + |
| 275 | +class ModelMixedGenericsEnforceContextMatch(CallableModel, Generic[TContext, TResult]): |
| 276 | + model: CallableModelGenericType[TContext, TResult] |
| 277 | + |
| 278 | + @property |
| 279 | + def context_type(self) -> Type[ContextType]: |
| 280 | + return MyGenericContext[self.model.context_type] |
| 281 | + |
| 282 | + @property |
| 283 | + def result_type(self) -> Type[ResultType]: |
| 284 | + return GenericResult[self.model.result_type] |
| 285 | + |
| 286 | + @Flow.deps |
| 287 | + def __deps__(self, context: MyGenericContext[TContext]) -> List[Tuple[CallableModelGenericType[TContext, TResult], List[ContextType]]]: |
| 288 | + return [] |
| 289 | + |
| 290 | + @Flow.call |
| 291 | + def __call__(self, context: MyGenericContext[TContext]) -> TResult: |
| 292 | + return GenericResult(value=None) |
| 293 | + |
| 294 | + |
271 | 295 | class TestContext(TestCase): |
272 | 296 | def test_immutable(self): |
273 | 297 | x = MyContext(a="foo") |
@@ -413,6 +437,10 @@ def test_identity(self): |
413 | 437 | self.assertEqual(ident(context), context) |
414 | 438 | self.assertIsNot(ident(context), context) |
415 | 439 |
|
| 440 | + def test_context_call_match_enforcement_generic_base(self): |
| 441 | + # This should not raise |
| 442 | + _ = ModelMixedGenericsEnforceContextMatch(model=IdentityCallable()) |
| 443 | + |
416 | 444 |
|
417 | 445 | class TestWrapperModel(TestCase): |
418 | 446 | def test_wrapper(self): |
|
0 commit comments