Skip to content

Commit 294f0ad

Browse files
committed
Fix annotation adjustment for pipe unions
Signed-off-by: Pascal Tomecek <pascal.tomecek@cubistsystematic.com>
1 parent 6e06771 commit 294f0ad

File tree

2 files changed

+19
-32
lines changed

2 files changed

+19
-32
lines changed

ccflow/base.py

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
import logging
77
import pathlib
88
import platform
9-
import typing
109
from types import GenericAlias, MappingProxyType
11-
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, get_args, get_origin
10+
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
1211

13-
import typing_extensions
1412
from omegaconf import DictConfig
1513
from packaging import version
1614
from pydantic import (
@@ -100,35 +98,18 @@ def get_registry_dependencies(self, types: Optional[Tuple["ModelType"]] = None)
10098
# NOTE: For this logic to be removed, require https://github.com/pydantic/pydantic-core/pull/1478
10199
from pydantic._internal._model_construction import ModelMetaclass # noqa: E402
102100

103-
# Required for py38 compatibility
104-
# In python 3.8, get_origin(List[float]) returns list, but you can't call list[float] to retrieve the annotation
105-
# Furthermore, Annotated is part of typing_Extensions and get_origin(Annotated[str, ...]) returns str rather than Annotated
106-
_IS_PY38 = version.parse(platform.python_version()) < version.parse("3.9")
107-
# For a more complete list, see https://github.com/alexmojaki/eval_type_backport/blob/main/eval_type_backport/eval_type_backport.py
108-
_PY38_ORIGIN_MAP = {
109-
tuple: typing.Tuple,
110-
list: typing.List,
111-
dict: typing.Dict,
112-
set: typing.Set,
113-
frozenset: typing.FrozenSet,
114-
collections.abc.Callable: typing.Callable,
115-
collections.abc.Iterable: typing.Iterable,
116-
collections.abc.Mapping: typing.Mapping,
117-
collections.abc.MutableMapping: typing.MutableMapping,
118-
collections.abc.Sequence: typing.Sequence,
119-
}
101+
_IS_PY39 = version.parse(platform.python_version()) < version.parse("3.10")
120102

121103

122104
def _adjust_annotations(annotation):
123105
origin = get_origin(annotation)
124-
if _IS_PY38:
125-
origin = _PY38_ORIGIN_MAP.get(origin, origin)
126-
if isinstance(annotation, typing_extensions._AnnotatedAlias):
127-
args = annotation.__metadata__
128-
else:
129-
args = get_args(annotation)
130-
else:
131-
args = get_args(annotation)
106+
args = get_args(annotation)
107+
if not _IS_PY39:
108+
from types import UnionType
109+
110+
if origin is UnionType:
111+
origin = Union
112+
132113
if isinstance(annotation, GenericAlias) or (inspect.isclass(annotation) and issubclass(annotation, PydanticBaseModel)):
133114
return SerializeAsAny[annotation]
134115
elif origin and args:
@@ -139,10 +120,6 @@ def _adjust_annotations(annotation):
139120
return ClassVar[_adjust_annotations(args[0])]
140121
else:
141122
try:
142-
if _IS_PY38 and isinstance(annotation, typing_extensions._AnnotatedAlias):
143-
if origin != annotation:
144-
origin = _adjust_annotations(origin)
145-
return typing_extensions.Annotated[(origin,) + tuple(_adjust_annotations(arg) for arg in args)]
146123
return origin[tuple(_adjust_annotations(arg) for arg in args)]
147124
except TypeError:
148125
raise TypeError(f"Could not adjust annotations for {origin}")

ccflow/tests/test_base_serialize.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import platform
12
import unittest
23
from typing import ClassVar, Dict, List, Optional, Type, Union
34

45
import numpy as np
6+
from packaging import version
57
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, ValidationError
68

79
from ccflow import BaseModel, NDArray
@@ -205,20 +207,27 @@ def test_serialize_as_any(self):
205207
from pydantic import SerializeAsAny
206208
from pydantic.types import constr
207209

210+
if version.parse(platform.python_version()) >= version.parse("3.10"):
211+
pipe_union = A | int
212+
else:
213+
pipe_union = Union[A, int]
214+
208215
class MyNestedModel(BaseModel):
209216
a1: A
210217
a2: Optional[Union[A, int]]
211218
a3: Dict[str, Optional[List[A]]]
212219
a4: ClassVar[A]
213220
a5: Type[A]
214221
a6: constr(min_length=1)
222+
a7: pipe_union
215223

216224
target = {
217225
"a1": SerializeAsAny[A],
218226
"a2": Optional[Union[SerializeAsAny[A], int]],
219227
"a4": ClassVar[SerializeAsAny[A]],
220228
"a5": Type[A],
221229
"a6": constr(min_length=1), # Uses Annotation
230+
"a7": Union[SerializeAsAny[A], int],
222231
}
223232
target["a3"] = dict[str, Optional[list[SerializeAsAny[A]]]]
224233
annotations = MyNestedModel.__annotations__
@@ -228,3 +237,4 @@ class MyNestedModel(BaseModel):
228237
self.assertEqual(str(annotations["a4"]), str(target["a4"]))
229238
self.assertEqual(str(annotations["a5"]), str(target["a5"]))
230239
self.assertEqual(str(annotations["a6"]), str(target["a6"]))
240+
self.assertEqual(str(annotations["a7"]), str(target["a7"]))

0 commit comments

Comments
 (0)