Skip to content

Commit 07046f0

Browse files
committed
Add all_type_leaves_satisfy_predicate, refactor type tests to use it
1 parent c51497d commit 07046f0

File tree

4 files changed

+123
-103
lines changed

4 files changed

+123
-103
lines changed

arraycontext/container/__init__.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@
135135
from typing import (
136136
TYPE_CHECKING,
137137
TypeAlias,
138-
get_origin,
139138
)
140139

141140
# For use in singledispatch type annotations, because sphinx can't figure out
@@ -147,14 +146,15 @@
147146
from pytools.obj_array import ObjectArray, ObjectArrayND as ObjectArrayND
148147

149148
from arraycontext.typing import (
150-
ArithArrayContainer,
149+
ArithArrayContainer as ArithArrayContainer,
151150
ArrayContainer,
152151
ArrayContainerT,
153152
ArrayOrArithContainer,
154153
ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar,
155154
ArrayOrContainerOrScalar,
156155
_UserDefinedArithArrayContainer,
157156
_UserDefinedArrayContainer,
157+
all_type_leaves_satisfy_predicate,
158158
)
159159

160160

@@ -233,23 +233,15 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool:
233233
function will say that :class:`numpy.ndarray` is an array container
234234
type, only object arrays *actually are* array containers.
235235
"""
236-
if cls is ArrayContainer or cls is ArithArrayContainer:
237-
return True
236+
def pred(tp: type) -> bool:
237+
return (
238+
tp is ObjectArray
239+
or tp is _UserDefinedArrayContainer
240+
or tp is _UserDefinedArithArrayContainer
241+
or (serialize_container.dispatch(tp)
242+
is not serialize_container.__wrapped__)) # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
238243

239-
origin = get_origin(cls)
240-
if origin is not None:
241-
cls = origin # pyright: ignore[reportAny]
242-
243-
assert isinstance(cls, type), (
244-
f"must pass a {type!r}, not a '{cls!r}'")
245-
246-
return (
247-
cls is ObjectArray
248-
or cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
249-
or cls is _UserDefinedArrayContainer
250-
or cls is _UserDefinedArithArrayContainer
251-
or (serialize_container.dispatch(cls)
252-
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
244+
return all_type_leaves_satisfy_predicate(pred, cls)
253245

254246

255247
def is_array_container(ary: object) -> TypeIs[ArrayContainer]:
@@ -265,7 +257,7 @@ def is_array_container(ary: object) -> TypeIs[ArrayContainer]:
265257
"cheaper option, see is_array_container_type.",
266258
DeprecationWarning, stacklevel=2)
267259
return (serialize_container.dispatch(ary.__class__)
268-
is not serialize_container.__wrapped__ # type:ignore[attr-defined]
260+
is not serialize_container.__wrapped__ # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
269261
# numpy values with scalar elements aren't array containers
270262
and not (isinstance(ary, np.ndarray)
271263
and ary.dtype.kind != "O")

arraycontext/container/dataclass.py

Lines changed: 32 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -37,30 +37,19 @@
3737
THE SOFTWARE.
3838
"""
3939

40-
# The import of 'Union' is type-ignored below because we're specifically importing
41-
# Union to pick apart type annotations.
42-
4340
from dataclasses import fields, is_dataclass
4441
from typing import (
4542
TYPE_CHECKING,
4643
NamedTuple,
4744
TypeVar,
48-
Union, # pyright: ignore[reportDeprecated]
49-
cast,
50-
get_args,
51-
get_origin,
5245
)
5346
from warnings import warn
5447

5548
import numpy as np
5649

57-
from pytools.obj_array import ObjectArray
58-
5950
from arraycontext.container import is_array_container_type
6051
from arraycontext.typing import (
61-
ArrayContainer,
62-
ArrayOrContainer,
63-
ArrayOrContainerOrScalar,
52+
all_type_leaves_satisfy_predicate,
6453
is_scalar_type,
6554
)
6655

@@ -83,17 +72,30 @@ class _Field(NamedTuple):
8372
type: type
8473

8574

86-
def _is_array_or_container_type(tp: type | GenericAlias | UnionType, /) -> bool:
87-
if tp is np.ndarray:
88-
warn("Encountered 'numpy.ndarray' in a dataclass_array_container. "
89-
"This is deprecated and will stop working in 2026. "
90-
"If you meant an object array, use pytools.obj_array.ObjectArray. "
91-
"For other uses, file an issue to discuss.",
92-
DeprecationWarning, stacklevel=3)
93-
return True
75+
def _is_array_or_container_type(
76+
tp: type | GenericAlias | UnionType | TypeVar, /, *,
77+
allow_scalar: bool = True,
78+
require_homogeneity: bool = True,
79+
) -> bool:
80+
def _is_array_or_container_or_scalar(tp: type) -> bool:
81+
if tp is np.ndarray:
82+
warn("Encountered 'numpy.ndarray' in a dataclass_array_container. "
83+
"This is deprecated and will stop working in 2026. "
84+
"If you meant an object array, use pytools.obj_array.ObjectArray. "
85+
"For other uses, file an issue to discuss.",
86+
DeprecationWarning, stacklevel=1)
87+
return True
88+
89+
from arraycontext import Array
90+
91+
return (
92+
is_array_container_type(tp)
93+
or tp is Array
94+
or (allow_scalar and is_scalar_type(tp)))
9495

95-
from arraycontext import Array
96-
return tp is Array or is_array_container_type(tp)
96+
return all_type_leaves_satisfy_predicate(
97+
_is_array_or_container_or_scalar, tp,
98+
require_homogeneity=require_homogeneity)
9799

98100

99101
def dataclass_array_container(cls: type[T]) -> type[T]:
@@ -120,8 +122,6 @@ def dataclass_array_container(cls: type[T]) -> type[T]:
120122
means that *cls* must live in a module that is importable.
121123
"""
122124

123-
from types import GenericAlias, UnionType
124-
125125
assert is_dataclass(cls)
126126

127127
def is_array_field(f: _Field) -> bool:
@@ -139,61 +139,17 @@ def is_array_field(f: _Field) -> bool:
139139
#
140140
# This is not set in stone, but mostly driven by current usage!
141141

142-
# pyright has no idea what we're up to. :)
143-
if field_type is ArrayContainer: # pyright: ignore[reportUnnecessaryComparison]
144-
return True
145-
if field_type is ArrayOrContainer: # pyright: ignore[reportUnnecessaryComparison]
146-
return True
147-
if field_type is ArrayOrContainerOrScalar: # pyright: ignore[reportUnnecessaryComparison]
148-
return True
149-
150-
origin = get_origin(field_type)
151-
152-
if origin is ObjectArray:
153-
return True
154-
155-
# NOTE: `UnionType` is returned when using `Type1 | Type2`
156-
if origin in (Union, UnionType): # pyright: ignore[reportDeprecated]
157-
for arg in get_args(field_type): # pyright: ignore[reportAny]
158-
if not (
159-
_is_array_or_container_type(cast("type", arg))
160-
or is_scalar_type(cast("type", arg))):
161-
raise TypeError(
162-
f"Field '{f.name}' union contains non-array container "
163-
f"type '{arg}'. All types must be array containers "
164-
"or arrays or scalars."
165-
)
166-
167-
return True
168-
169142
# NOTE: this should never happen due to using `inspect.get_annotations`
170143
assert not isinstance(field_type, str)
171144

172-
if __debug__:
173-
if not f.init:
174-
raise ValueError(
175-
f"Field with 'init=False' not allowed: '{f.name}'")
176-
177-
# NOTE:
178-
# * `GenericAlias` catches typed `list`, `tuple`, etc.
179-
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
180-
# * `_SpecialForm` catches `Any`, `Literal`, etc.
181-
from typing import ( # type: ignore[attr-defined]
182-
_BaseGenericAlias,
183-
_SpecialForm,
184-
)
185-
if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm):
186-
# NOTE: anything except a Union is not allowed
187-
raise TypeError(
188-
f"Type annotation not supported on field '{f.name}': "
189-
f"'{field_type!r}'")
190-
191-
if not isinstance(field_type, type):
192-
raise TypeError(
193-
f"Field '{f.name}' not an instance of 'type': "
194-
f"'{field_type!r}'")
195-
196-
return _is_array_or_container_type(field_type)
145+
if not f.init:
146+
raise ValueError(
147+
f"Field with 'init=False' not allowed: '{f.name}'")
148+
149+
try:
150+
return _is_array_or_container_type(field_type)
151+
except TypeError as e:
152+
raise TypeError(f"Field '{f.name}': {e}") from None
197153

198154
from pytools import partition
199155

arraycontext/typing.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@
7171
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
7272
THE SOFTWARE.
7373
"""
74+
# The import of 'Union' is type-ignored below because we're specifically importing
75+
# Union to pick apart old/deprecated type annotations.
76+
77+
from functools import partial
78+
from types import GenericAlias, UnionType
7479
from typing import (
7580
TYPE_CHECKING,
7681
Any,
@@ -80,7 +85,9 @@
8085
SupportsInt,
8186
TypeAlias,
8287
TypeVar,
88+
Union, # pyright: ignore[reportDeprecated]
8389
cast,
90+
get_args,
8491
get_origin,
8592
overload,
8693
)
@@ -89,10 +96,13 @@
8996
from typing_extensions import Self, TypeIs
9097

9198
from pymbolic.typing import Integer, Scalar as _Scalar
99+
from pytools import partition2
92100
from pytools.obj_array import ObjectArrayND
93101

94102

95103
if TYPE_CHECKING:
104+
from collections.abc import Callable
105+
96106
from numpy.typing import DTypeLike
97107

98108
from pymbolic.typing import Integer
@@ -296,3 +306,56 @@ def shape_is_int_only(shape: tuple[Array | Integer, ...], /) -> tuple[int, ...]:
296306
) from None
297307

298308
return tuple(res)
309+
310+
311+
def all_type_leaves_satisfy_predicate(
312+
predicate: Callable[[type], bool],
313+
tp: type | GenericAlias | UnionType | TypeVar,
314+
/, *,
315+
require_homogeneity: bool = False,
316+
allow_containers_with_satisfying_types: bool = False,
317+
) -> bool:
318+
# This is horrible and brittle. I'm sorry.
319+
320+
rec = partial(
321+
all_type_leaves_satisfy_predicate,
322+
predicate,
323+
require_homogeneity=require_homogeneity,
324+
allow_containers_with_satisfying_types=allow_containers_with_satisfying_types
325+
)
326+
origin = get_origin(tp)
327+
args = get_args(tp)
328+
tp_or_origin = tp if origin is None else origin
329+
330+
if isinstance(tp_or_origin, TypeVar):
331+
bound = cast("type | None", tp_or_origin.__bound__)
332+
if bound is None:
333+
return False
334+
else:
335+
return rec(bound)
336+
337+
# NOTE: `UnionType` is returned when using `Type1 | Type2`
338+
if origin in (Union, UnionType): # pyright: ignore[reportDeprecated]
339+
yes_types, no_types = partition2(
340+
(rec(arg), arg) for arg in args) # pyright: ignore[reportAny]
341+
if require_homogeneity and yes_types and no_types:
342+
raise TypeError(f"union '{tp}' is non-homogeneous "
343+
f"in whether it satisfies '{predicate}'")
344+
345+
return not no_types
346+
347+
if not isinstance(tp_or_origin, type):
348+
raise TypeError(f"encountered non-type '{type(tp_or_origin)!r}'")
349+
350+
if predicate(tp_or_origin):
351+
return True
352+
353+
if args and not allow_containers_with_satisfying_types:
354+
# assume these are containers
355+
has_sat_types = any(rec(arg) for arg in args) # pyright: ignore[reportAny]
356+
357+
if has_sat_types:
358+
raise TypeError(f"container '{tp}' has an element type "
359+
f"satisfying '{predicate}'")
360+
361+
return False

test/test_utils.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ def test_dataclass_array_container() -> None:
7272
class ArrayContainerWithOptional:
7373
x: np.ndarray
7474
# Deliberately left as Optional to test compatibility.
75-
y: Optional[np.ndarray] # noqa: UP045
75+
y: Optional[np.ndarray] # noqa: UP045 # pyright: ignore[reportDeprecated]
7676

77-
with pytest.raises(TypeError, match="Field 'y' union contains non-array"):
77+
with pytest.raises(TypeError, match=r"Field 'y':.*non-homogeneous.*"):
7878
# NOTE: cannot have wrapped annotations (here by `Optional`)
7979
dataclass_array_container(ArrayContainerWithOptional)
8080

@@ -88,15 +88,15 @@ class ArrayContainerWithTuple:
8888
# Deliberately left as Tuple to test compatibility.
8989
y: Tuple[Array, Array] # noqa: UP006
9090

91-
with pytest.raises(TypeError, match="Type annotation not supported on field 'y'"):
91+
with pytest.raises(TypeError, match=r"Field 'y':.*has an element type.*"):
9292
dataclass_array_container(ArrayContainerWithTuple)
9393

9494
@dataclass
9595
class ArrayContainerWithTupleAlt:
9696
x: Array
9797
y: tuple[Array, Array]
9898

99-
with pytest.raises(TypeError, match="Type annotation not supported on field 'y'"):
99+
with pytest.raises(TypeError, match=r"Field 'y':.*has an element type.*"):
100100
dataclass_array_container(ArrayContainerWithTupleAlt)
101101

102102
# }}}
@@ -159,12 +159,21 @@ class ArrayContainerWithUnionAlt:
159159
@dataclass
160160
class ArrayContainerWithWrongUnion:
161161
x: np.ndarray
162-
y: np.ndarray | list[bool]
162+
y: np.ndarray | list[str]
163163

164-
with pytest.raises(TypeError, match="Field 'y' union contains non-array container"):
165-
# NOTE: bool is not an ArrayContainer, so y should fail
164+
with pytest.raises(TypeError, match=r"Field 'y':.*non-homogeneous.*"):
165+
# NOTE: str is not an ArrayContainer, so y should fail
166166
dataclass_array_container(ArrayContainerWithWrongUnion)
167167

168+
@dataclass
169+
class ArrayContainerWithWrongUnion2:
170+
x: np.ndarray
171+
y: np.ndarray | str
172+
173+
with pytest.raises(TypeError, match=r"Field 'y':.*non-homogeneous.*"):
174+
# NOTE: str is not an ArrayContainer, so y should fail
175+
dataclass_array_container(ArrayContainerWithWrongUnion2)
176+
168177
# }}}
169178

170179
# {{{ optional union
@@ -174,7 +183,7 @@ class ArrayContainerWithOptionalUnion:
174183
x: np.ndarray
175184
y: np.ndarray | None
176185

177-
with pytest.raises(TypeError, match="Field 'y' union contains non-array container"):
186+
with pytest.raises(TypeError, match=r"Field 'y':.*non-homogeneous.*"):
178187
# NOTE: None is not an ArrayContainer, so y should fail
179188
dataclass_array_container(ArrayContainerWithWrongUnion)
180189

0 commit comments

Comments
 (0)