|
1 | 1 | from typing import Any, Dict, List |
2 | | -from unittest import TestCase |
| 2 | +from unittest import TestCase, mock |
3 | 3 |
|
4 | | -from pydantic import ConfigDict, ValidationError |
| 4 | +from pydantic import BaseModel as PydanticBaseModel, ConfigDict, ValidationError |
5 | 5 |
|
6 | | -from ccflow import BaseModel, PyObjectPath |
| 6 | +from ccflow import BaseModel, CallableModel, ContextBase, Flow, GenericResult, NullContext, PyObjectPath |
| 7 | +from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME |
7 | 8 |
|
8 | 9 |
|
9 | 10 | class ModelA(BaseModel): |
@@ -106,6 +107,20 @@ def test_type_after_assignment(self): |
106 | 107 | self.assertIsInstance(m.type_, PyObjectPath) |
107 | 108 | self.assertEqual(m.type_, path) |
108 | 109 |
|
| 110 | + def test_pyobjectpath_requires_ccflow_local_registration(self): |
| 111 | + class PlainLocalModel(PydanticBaseModel): |
| 112 | + value: int |
| 113 | + |
| 114 | + with self.assertRaises(ValueError): |
| 115 | + PyObjectPath.validate(PlainLocalModel) |
| 116 | + |
| 117 | + class LocalCcflowModel(BaseModel): |
| 118 | + value: int |
| 119 | + |
| 120 | + path = PyObjectPath.validate(LocalCcflowModel) |
| 121 | + self.assertEqual(path.object, LocalCcflowModel) |
| 122 | + self.assertTrue(str(path).startswith(f"{LOCAL_ARTIFACTS_MODULE_NAME}.")) |
| 123 | + |
109 | 124 | def test_validate(self): |
110 | 125 | self.assertEqual(ModelA.model_validate({"x": "foo"}), ModelA(x="foo")) |
111 | 126 | type_ = "ccflow.tests.test_base.ModelA" |
@@ -157,3 +172,52 @@ def test_widget(self): |
157 | 172 | {"expanded": True, "root": "bar"}, |
158 | 173 | ), |
159 | 174 | ) |
| 175 | + |
| 176 | + |
| 177 | +class TestLocalRegistrationKind(TestCase): |
| 178 | + def test_base_model_defaults_to_model_kind(self): |
| 179 | + with mock.patch("ccflow.base.register_local_subclass") as register: |
| 180 | + |
| 181 | + class LocalModel(BaseModel): |
| 182 | + value: int |
| 183 | + |
| 184 | + register.assert_called_once() |
| 185 | + args, kwargs = register.call_args |
| 186 | + self.assertIs(args[0], LocalModel) |
| 187 | + self.assertEqual(kwargs["kind"], "model") |
| 188 | + |
| 189 | + def test_context_defaults_to_context_kind(self): |
| 190 | + with mock.patch("ccflow.base.register_local_subclass") as register: |
| 191 | + |
| 192 | + class LocalContext(ContextBase): |
| 193 | + value: int |
| 194 | + |
| 195 | + register.assert_called_once() |
| 196 | + args, kwargs = register.call_args |
| 197 | + self.assertIs(args[0], LocalContext) |
| 198 | + self.assertEqual(kwargs["kind"], "context") |
| 199 | + |
| 200 | + def test_callable_defaults_to_callable_kind(self): |
| 201 | + with mock.patch("ccflow.base.register_local_subclass") as register: |
| 202 | + |
| 203 | + class LocalCallable(CallableModel): |
| 204 | + @Flow.call |
| 205 | + def __call__(self, context: NullContext) -> GenericResult: |
| 206 | + return GenericResult(value="ok") |
| 207 | + |
| 208 | + register.assert_called_once() |
| 209 | + args, kwargs = register.call_args |
| 210 | + self.assertIs(args[0], LocalCallable) |
| 211 | + self.assertEqual(kwargs["kind"], "callable_model") |
| 212 | + |
| 213 | + def test_explicit_override_respected(self): |
| 214 | + with mock.patch("ccflow.base.register_local_subclass") as register: |
| 215 | + |
| 216 | + class CustomKind(BaseModel): |
| 217 | + __ccflow_local_registration_kind__ = "custom" |
| 218 | + value: int |
| 219 | + |
| 220 | + register.assert_called_once() |
| 221 | + args, kwargs = register.call_args |
| 222 | + self.assertIs(args[0], CustomKind) |
| 223 | + self.assertEqual(kwargs["kind"], "custom") |
0 commit comments