Skip to content

Commit 84fe35f

Browse files
committed
Towards grudge typing
This improves many aspects of typing in arraycontext: - It improves type checking (and consistency) in the traversals. - It allows scalars consistently. - It adds some overloads for traversal functions. - It adds `rec_map_container`, which is simpler (and easier to type). - It makes array containers recognizable to the type checker. This works via a heuristic, by having `__array_ufunc__ == None`. - It adds more types in the base fake numpy and shifts some implementation aspects there.
1 parent 9f0934b commit 84fe35f

22 files changed

+1398
-615
lines changed

arraycontext/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
multimapped_over_array_containers,
6464
outer,
6565
rec_map_array_container,
66+
rec_map_container,
6667
rec_map_reduce_array_container,
6768
rec_multimap_array_container,
6869
rec_multimap_reduce_array_container,
@@ -84,6 +85,8 @@
8485
ArrayOrContainerOrScalar,
8586
ArrayOrContainerOrScalarT,
8687
ArrayOrContainerT,
88+
ArrayOrScalar,
89+
ArrayOrScalarT,
8790
ArrayT,
8891
Scalar,
8992
ScalarLike,
@@ -117,6 +120,8 @@
117120
"ArrayOrContainerOrScalar",
118121
"ArrayOrContainerOrScalarT",
119122
"ArrayOrContainerT",
123+
"ArrayOrScalar",
124+
"ArrayOrScalarT",
120125
"ArrayT",
121126
"BcastUntilActxArray",
122127
"CommonSubexpressionTag",
@@ -154,6 +159,7 @@
154159
"outer",
155160
"pytest_generate_tests_for_array_contexts",
156161
"rec_map_array_container",
162+
"rec_map_container",
157163
"rec_map_reduce_array_container",
158164
"rec_multimap_array_container",
159165
"rec_multimap_reduce_array_container",

arraycontext/container/__init__.py

Lines changed: 99 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,53 @@
1-
# mypy: disallow-untyped-defs
2-
31
"""
42
.. currentmodule:: arraycontext
53
6-
.. autoclass:: ArrayContainer
4+
.. class:: ArrayContainer
5+
A protocol for generic containers of the array type supported by the
6+
:class:`ArrayContext`.
7+
8+
The functionality required for the container to operated is supplied via
9+
:func:`functools.singledispatch`. Implementations of the following functions need
10+
to be registered for a type serving as an :class:`ArrayContainer`:
11+
12+
* :func:`serialize_container` for serialization, which gives the components
13+
of the array.
14+
* :func:`deserialize_container` for deserialization, which constructs a
15+
container from a set of components.
16+
* :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
17+
a container, if it has one.
18+
19+
This allows enumeration of the component arrays in a container and the
20+
construction of modified containers from an iterable of those component arrays.
21+
22+
Packages may register their own types as array containers. They must not
23+
register other types (e.g. :class:`list`) as array containers.
24+
The type :class:`numpy.ndarray` is considered an array container, but
25+
only arrays with dtype *object* may be used as such. (This is so
26+
because object arrays cannot be distinguished from non-object arrays
27+
via their type.)
28+
29+
The container and its serialization interface has goals and uses
30+
approaches similar to JAX's
31+
`PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
32+
however its implementation differs a bit.
33+
34+
.. note::
35+
36+
This class is used in type annotation and as a marker of array container
37+
attributes for :func:`~arraycontext.dataclass_array_container`.
38+
As a protocol, it is not intended as a superclass.
39+
40+
.. note::
41+
42+
For the benefit of type checkers, array containers are recognized by
43+
having the declaration::
44+
45+
__array_ufunc__: ClassVar[None] = None
46+
47+
in their body. In addition to its use as a recognition feature, this also
48+
prevents unintended arithmetic in conjunction with :mod:`numpy` arrays.
49+
This should be considered experimental for now, and it may well change.
50+
751
.. autoclass:: ArithArrayContainer
852
.. class:: ArrayContainerT
953
@@ -51,6 +95,12 @@
5195

5296
from __future__ import annotations
5397

98+
from types import GenericAlias, UnionType
99+
100+
from numpy.typing import NDArray
101+
102+
from arraycontext.context import ArrayOrArithContainer, ArrayOrContainerOrScalar
103+
54104

55105
__copyright__ = """
56106
Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -78,75 +128,45 @@
78128

79129
from collections.abc import Hashable, Sequence
80130
from functools import singledispatch
81-
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar
131+
from typing import (
132+
TYPE_CHECKING,
133+
Any,
134+
ClassVar,
135+
Protocol,
136+
TypeAlias,
137+
TypeVar,
138+
get_origin,
139+
)
82140

83141
# For use in singledispatch type annotations, because sphinx can't figure out
84142
# what 'np' is.
85143
import numpy
86144
import numpy as np
87-
from typing_extensions import Self
145+
from typing_extensions import Self, TypeIs
88146

89147

90148
if TYPE_CHECKING:
91-
from pymbolic.geometric_algebra import MultiVector
149+
from pymbolic.geometric_algebra import CoeffT, MultiVector
92150

93-
from arraycontext import ArrayOrContainer
94151
from arraycontext.context import ArrayContext, ArrayOrScalar
95152

96153

97154
# {{{ ArrayContainer
98155

99-
class ArrayContainer(Protocol):
100-
"""
101-
A protocol for generic containers of the array type supported by the
102-
:class:`ArrayContext`.
103-
104-
The functionality required for the container to operated is supplied via
105-
:func:`functools.singledispatch`. Implementations of the following functions need
106-
to be registered for a type serving as an :class:`ArrayContainer`:
107-
108-
* :func:`serialize_container` for serialization, which gives the components
109-
of the array.
110-
* :func:`deserialize_container` for deserialization, which constructs a
111-
container from a set of components.
112-
* :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
113-
a container, if it has one.
114-
115-
This allows enumeration of the component arrays in a container and the
116-
construction of modified containers from an iterable of those component arrays.
117-
118-
Packages may register their own types as array containers. They must not
119-
register other types (e.g. :class:`list`) as array containers.
120-
The type :class:`numpy.ndarray` is considered an array container, but
121-
only arrays with dtype *object* may be used as such. (This is so
122-
because object arrays cannot be distinguished from non-object arrays
123-
via their type.)
124-
125-
The container and its serialization interface has goals and uses
126-
approaches similar to JAX's
127-
`PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
128-
however its implementation differs a bit.
129-
130-
.. note::
131-
132-
This class is used in type annotation and as a marker of array container
133-
attributes for :func:`~arraycontext.dataclass_array_container`.
134-
As a protocol, it is not intended as a superclass.
135-
"""
136-
137-
# Array containers do not need to have any particular features, so this
138-
# protocol is deliberately empty.
139-
140-
# This *is* used as a type annotation in dataclasses that are processed
156+
class _UserDefinedArrayContainer(Protocol):
157+
# This is used as a type annotation in dataclasses that are processed
141158
# by dataclass_array_container, where it's used to recognize attributes
142159
# that are container-typed.
143160

161+
# This method prevents ArrayContainer from matching any object, while
162+
# matching numpy object arrays and many array containers.
163+
__array_ufunc__: ClassVar[None]
144164

145-
class ArithArrayContainer(ArrayContainer, Protocol):
146-
"""
147-
A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
148-
"""
149165

166+
ArrayContainer: TypeAlias = NDArray[Any] | _UserDefinedArrayContainer
167+
168+
169+
class _UserDefinedArithArrayContainer(_UserDefinedArrayContainer, Protocol):
150170
# This is loose and permissive, assuming that any array can be added
151171
# to any container. The alternative would be to plaster type-ignores
152172
# on all those uses. Achieving typing precision on what broadcasting is
@@ -167,6 +187,9 @@ def __pow__(self, other: ArrayOrScalar | Self) -> Self: ...
167187
def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...
168188

169189

190+
ArithArrayContainer: TypeAlias = NDArray[Any] | _UserDefinedArithArrayContainer
191+
192+
170193
ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
171194

172195

@@ -175,7 +198,8 @@ class NotAnArrayContainerError(TypeError):
175198

176199

177200
SerializationKey: TypeAlias = Hashable
178-
SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]]
201+
SerializedContainer: TypeAlias = Sequence[
202+
tuple[SerializationKey, ArrayOrContainerOrScalar]]
179203

180204

181205
@singledispatch
@@ -221,7 +245,7 @@ def deserialize_container(
221245
f"'{type(template).__name__}' cannot be deserialized as a container")
222246

223247

224-
def is_array_container_type(cls: type) -> bool:
248+
def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool:
225249
"""
226250
:returns: *True* if the type *cls* has a registered implementation of
227251
:func:`serialize_container`, or if it is an :class:`ArrayContainer`.
@@ -233,15 +257,22 @@ def is_array_container_type(cls: type) -> bool:
233257
function will say that :class:`numpy.ndarray` is an array container
234258
type, only object arrays *actually are* array containers.
235259
"""
236-
assert isinstance(cls, type), f"must pass a {type!r}, not a '{cls!r}'"
260+
if cls is ArrayContainer:
261+
return True
262+
263+
while isinstance(cls, GenericAlias):
264+
cls = get_origin(cls)
265+
266+
assert isinstance(cls, type), (
267+
f"must pass a {type!r}, not a '{cls!r}'")
237268

238269
return (
239-
cls is ArrayContainer
270+
cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
240271
or (serialize_container.dispatch(cls)
241272
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
242273

243274

244-
def is_array_container(ary: object) -> bool:
275+
def is_array_container(ary: object) -> TypeIs[ArrayContainer]:
245276
"""
246277
:returns: *True* if the instance *ary* has a registered implementation of
247278
:func:`serialize_container`.
@@ -317,7 +348,7 @@ def _deserialize_ndarray_container( # type: ignore[misc]
317348
# {{{ get_container_context_recursively
318349

319350
def get_container_context_recursively_opt(
320-
ary: ArrayContainer) -> ArrayContext | None:
351+
ary: ArrayOrContainerOrScalar) -> ArrayContext | None:
321352
"""Walks the :class:`ArrayContainer` hierarchy to find an
322353
:class:`ArrayContext` associated with it.
323354
@@ -351,7 +382,7 @@ def get_container_context_recursively_opt(
351382
return actx
352383

353384

354-
def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None:
385+
def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext:
355386
"""Walks the :class:`ArrayContainer` hierarchy to find an
356387
:class:`ArrayContext` associated with it.
357388
@@ -362,13 +393,7 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
362393
"""
363394
actx = get_container_context_recursively_opt(ary)
364395
if actx is None:
365-
# raise ValueError("no array context was found")
366-
from warnings import warn
367-
warn("No array context was found. This will be an error starting in "
368-
"July of 2022. If you would like the function to return "
369-
"None if no array context was found, use "
370-
"get_container_context_recursively_opt.",
371-
DeprecationWarning, stacklevel=2)
396+
raise ValueError("no array context was found")
372397

373398
return actx
374399

@@ -380,19 +405,20 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
380405
# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
381406
# (Though clearly there exists a dependency via loopy.)
382407

383-
def _serialize_multivec_as_container(mv: MultiVector) -> SerializedContainer:
408+
def _serialize_multivec_as_container(
409+
mv: MultiVector[ArrayOrArithContainer]
410+
) -> SerializedContainer:
384411
return list(mv.data.items())
385412

386413

387-
# FIXME: Ignored due to https://github.com/python/mypy/issues/13040
388-
def _deserialize_multivec_as_container( # type: ignore[misc]
389-
template: MultiVector,
390-
serialized: SerializedContainer) -> MultiVector:
414+
def _deserialize_multivec_as_container(
415+
template: MultiVector[CoeffT],
416+
serialized: SerializedContainer) -> MultiVector[CoeffT]:
391417
from pymbolic.geometric_algebra import MultiVector
392418
return MultiVector(dict(serialized), space=template.space)
393419

394420

395-
def _get_container_context_opt_from_multivec(mv: MultiVector) -> None:
421+
def _get_container_context_opt_from_multivec(mv: MultiVector[CoeffT]) -> None:
396422
return None
397423

398424

arraycontext/container/arithmetic.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: disallow-untyped-defs
21
from __future__ import annotations
32

43

@@ -62,7 +61,11 @@
6261
if TYPE_CHECKING:
6362
from collections.abc import Callable
6463

65-
from arraycontext.context import ArrayContext, ArrayOrContainer
64+
from arraycontext.context import (
65+
ArrayContext,
66+
ArrayOrContainer,
67+
ArrayOrContainerOrScalar,
68+
)
6669

6770

6871
# {{{ with_container_arithmetic
@@ -772,11 +775,11 @@ def __post_init__(self) -> None:
772775

773776
def _binary_op(self,
774777
op: Callable[
775-
[ArrayOrContainer, ArrayOrContainer],
776-
ArrayOrContainer
778+
[ArrayOrContainerOrScalar, ArrayOrContainerOrScalar],
779+
ArrayOrContainerOrScalar
777780
],
778-
right: ArrayOrContainer
779-
) -> ArrayOrContainer:
781+
right: ArrayOrContainerOrScalar
782+
) -> ArrayOrContainerOrScalar:
780783
try:
781784
serialized = serialize_container(right)
782785
except NotAnArrayContainerError:
@@ -791,11 +794,11 @@ def _binary_op(self,
791794

792795
def _rev_binary_op(self,
793796
op: Callable[
794-
[ArrayOrContainer, ArrayOrContainer],
795-
ArrayOrContainer
797+
[ArrayOrContainerOrScalar, ArrayOrContainerOrScalar],
798+
ArrayOrContainerOrScalar
796799
],
797-
left: ArrayOrContainer
798-
) -> ArrayOrContainer:
800+
left: ArrayOrContainerOrScalar
801+
) -> ArrayOrContainerOrScalar:
799802
try:
800803
serialized = serialize_container(left)
801804
except NotAnArrayContainerError:

arraycontext/container/dataclass.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# mypy: disallow-untyped-defs
2-
31
"""
42
.. currentmodule:: arraycontext
53
.. autofunction:: dataclass_array_container
@@ -34,7 +32,7 @@
3432
from dataclasses import fields, is_dataclass
3533
from typing import TYPE_CHECKING, NamedTuple, Union, get_args, get_origin
3634

37-
from arraycontext.container import is_array_container_type
35+
from arraycontext.container import ArrayContainer, is_array_container_type
3836

3937

4038
if TYPE_CHECKING:
@@ -99,7 +97,12 @@ def is_array_field(f: _Field) -> bool:
9997
#
10098
# This is not set in stone, but mostly driven by current usage!
10199

100+
# pyright has no idea what we're up to. :)
101+
if field_type is ArrayContainer: # pyright: ignore[reportUnnecessaryComparison]
102+
return True
103+
102104
origin = get_origin(field_type)
105+
103106
# NOTE: `UnionType` is returned when using `Type1 | Type2`
104107
if origin in (Union, UnionType):
105108
if all(is_array_type(arg) for arg in get_args(field_type)):

0 commit comments

Comments
 (0)