44.. currentmodule:: arraycontext
55.. autofunction:: dataclass_array_container
66"""
7+ from __future__ import annotations
78
89
910__copyright__ = """
3031THE SOFTWARE.
3132"""
3233
34+ from collections .abc import Mapping , Sequence
3335from dataclasses import Field , fields , is_dataclass
3436from typing import Union , get_args , get_origin
3537
@@ -57,11 +59,21 @@ def dataclass_array_container(cls: type) -> type:
5759 * a :class:`typing.Union` of array containers is considered an array container.
5860 * other type annotations, e.g. :class:`typing.Optional`, are not considered
5961 array containers, even if they wrap one.
62+
63+ .. note::
64+
65+ When type annotations are strings (e.g. because of
66+ ``from __future__ import annotations``),
67+ this function relies on :func:`inspect.get_annotations`
68+ (with ``eval_str=True``) to obtain type annotations. This
69+ means that *cls* must live in a module that is importable.
6070 """
6171
72+ from types import GenericAlias , UnionType
73+
6274 assert is_dataclass (cls )
6375
64- def is_array_field (f : Field ) -> bool :
76+ def is_array_field (f : Field , field_type : type ) -> bool :
6577 # NOTE: unions of array containers are treated separately to handle
6678 # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
6779 # they can work seamlessly with arithmetic and traversal.
@@ -74,16 +86,17 @@ def is_array_field(f: Field) -> bool:
7486 #
7587 # This is not set in stone, but mostly driven by current usage!
7688
77- origin = get_origin (f .type )
78- if origin is Union :
79- if all (is_array_type (arg ) for arg in get_args (f .type )):
89+ origin = get_origin (field_type )
90+ # NOTE: `UnionType` is returned when using `Type1 | Type2`
91+ if origin in (Union , UnionType ):
92+ if all (is_array_type (arg ) for arg in get_args (field_type )):
8093 return True
8194 else :
8295 raise TypeError (
8396 f"Field '{ f .name } ' union contains non-array container "
8497 "arguments. All arguments must be array containers." )
8598
86- if isinstance (f . type , str ):
99+ if isinstance (field_type , str ):
87100 raise TypeError (
88101 f"String annotation on field '{ f .name } ' not supported. "
89102 "(this may be due to 'from __future__ import annotations')" )
@@ -94,39 +107,56 @@ def is_array_field(f: Field) -> bool:
94107 f"Field with 'init=False' not allowed: '{ f .name } '" )
95108
96109 # NOTE:
110+ # * `GenericAlias` catches typed `list`, `tuple`, etc.
97111 # * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
98112 # * `_SpecialForm` catches `Any`, `Literal`, etc.
99113 from typing import ( # type: ignore[attr-defined]
100114 _BaseGenericAlias ,
101115 _SpecialForm ,
102116 )
103- if isinstance (f . type , _BaseGenericAlias | _SpecialForm ):
117+ if isinstance (field_type , GenericAlias | _BaseGenericAlias | _SpecialForm ):
104118 # NOTE: anything except a Union is not allowed
105119 raise TypeError (
106120 f"Typing annotation not supported on field '{ f .name } ': "
107- f"'{ f . type !r} '" )
121+ f"'{ field_type !r} '" )
108122
109- if not isinstance (f . type , type ):
123+ if not isinstance (field_type , type ):
110124 raise TypeError (
111125 f"Field '{ f .name } ' not an instance of 'type': "
112- f"'{ f .type !r} '" )
126+ f"'{ field_type !r} '" )
127+
128+ return is_array_type (field_type )
129+
130+ from inspect import get_annotations
113131
114- return is_array_type (f .type )
132+ array_fields : list [Field ] = []
133+ non_array_fields : list [Field ] = []
134+ cls_ann : Mapping [str , type ] | None = None
135+ for field in fields (cls ):
136+ field_type_or_str = field .type
137+ if isinstance (field_type_or_str , str ):
138+ if cls_ann is None :
139+ cls_ann = get_annotations (cls , eval_str = True )
140+ field_type = cls_ann [field .name ]
141+ else :
142+ field_type = field_type_or_str
115143
116- from pytools import partition
117- array_fields , non_array_fields = partition (is_array_field , fields (cls ))
144+ if is_array_field (field , field_type ):
145+ array_fields .append (field )
146+ else :
147+ non_array_fields .append (field )
118148
119149 if not array_fields :
120150 raise ValueError (f"'{ cls } ' must have fields with array container type "
121151 "in order to use the 'dataclass_array_container' decorator" )
122152
123- return inject_dataclass_serialization (cls , array_fields , non_array_fields )
153+ return _inject_dataclass_serialization (cls , array_fields , non_array_fields )
124154
125155
126- def inject_dataclass_serialization (
156+ def _inject_dataclass_serialization (
127157 cls : type ,
128- array_fields : tuple [Field , ... ],
129- non_array_fields : tuple [Field , ... ]) -> type :
158+ array_fields : Sequence [Field ],
159+ non_array_fields : Sequence [Field ]) -> type :
130160 """Implements :func:`~arraycontext.serialize_container` and
131161 :func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
132162
0 commit comments