1- # mypy: disallow-untyped-defs
2-
31"""
42.. currentmodule:: arraycontext
53
6- .. autoclass:: ArrayContainer
4+ .. class:: ArrayContainer
5+ A protocol for generic containers of the array type supported by the
6+ :class:`ArrayContext`.
7+
8+ The functionality required for the container to operated is supplied via
9+ :func:`functools.singledispatch`. Implementations of the following functions need
10+ to be registered for a type serving as an :class:`ArrayContainer`:
11+
12+ * :func:`serialize_container` for serialization, which gives the components
13+ of the array.
14+ * :func:`deserialize_container` for deserialization, which constructs a
15+ container from a set of components.
16+ * :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
17+ a container, if it has one.
18+
19+ This allows enumeration of the component arrays in a container and the
20+ construction of modified containers from an iterable of those component arrays.
21+
22+ Packages may register their own types as array containers. They must not
23+ register other types (e.g. :class:`list`) as array containers.
24+ The type :class:`numpy.ndarray` is considered an array container, but
25+ only arrays with dtype *object* may be used as such. (This is so
26+ because object arrays cannot be distinguished from non-object arrays
27+ via their type.)
28+
29+ The container and its serialization interface has goals and uses
30+ approaches similar to JAX's
31+ `PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
32+ however its implementation differs a bit.
33+
34+ .. note::
35+
36+ This class is used in type annotation and as a marker of array container
37+ attributes for :func:`~arraycontext.dataclass_array_container`.
38+ As a protocol, it is not intended as a superclass.
39+
40+ .. note::
41+
42+ For the benefit of type checkers, array containers are recognized by
43+ having the declaration::
44+
45+ __array_ufunc__: ClassVar[None] = None
46+
47+ in their body. In addition to its use as a recognition feature, this also
48+ prevents unintended arithmetic in conjunction with :mod:`numpy` arrays.
49+ This should be considered experimental for now, and it may well change.
50+
751.. autoclass:: ArithArrayContainer
852.. class:: ArrayContainerT
953
5195
5296from __future__ import annotations
5397
98+ from types import GenericAlias , UnionType
99+
100+ from numpy .typing import NDArray
101+
102+ from arraycontext .context import ArrayOrArithContainer , ArrayOrContainerOrScalar
103+
54104
55105__copyright__ = """
56106Copyright (C) 2020-1 University of Illinois Board of Trustees
78128
79129from collections .abc import Hashable , Sequence
80130from functools import singledispatch
81- from typing import TYPE_CHECKING , Protocol , TypeAlias , TypeVar
131+ from typing import (
132+ TYPE_CHECKING ,
133+ Any ,
134+ ClassVar ,
135+ Protocol ,
136+ TypeAlias ,
137+ TypeVar ,
138+ get_origin ,
139+ )
82140
83141# For use in singledispatch type annotations, because sphinx can't figure out
84142# what 'np' is.
85143import numpy
86144import numpy as np
87- from typing_extensions import Self
145+ from typing_extensions import Self , TypeIs
88146
89147
90148if TYPE_CHECKING :
91- from pymbolic .geometric_algebra import MultiVector
149+ from pymbolic .geometric_algebra import CoeffT , MultiVector
92150
93- from arraycontext import ArrayOrContainer
94151 from arraycontext .context import ArrayContext , ArrayOrScalar
95152
96153
97154# {{{ ArrayContainer
98155
99- class ArrayContainer (Protocol ):
100- """
101- A protocol for generic containers of the array type supported by the
102- :class:`ArrayContext`.
103-
104- The functionality required for the container to operated is supplied via
105- :func:`functools.singledispatch`. Implementations of the following functions need
106- to be registered for a type serving as an :class:`ArrayContainer`:
107-
108- * :func:`serialize_container` for serialization, which gives the components
109- of the array.
110- * :func:`deserialize_container` for deserialization, which constructs a
111- container from a set of components.
112- * :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
113- a container, if it has one.
114-
115- This allows enumeration of the component arrays in a container and the
116- construction of modified containers from an iterable of those component arrays.
117-
118- Packages may register their own types as array containers. They must not
119- register other types (e.g. :class:`list`) as array containers.
120- The type :class:`numpy.ndarray` is considered an array container, but
121- only arrays with dtype *object* may be used as such. (This is so
122- because object arrays cannot be distinguished from non-object arrays
123- via their type.)
124-
125- The container and its serialization interface has goals and uses
126- approaches similar to JAX's
127- `PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
128- however its implementation differs a bit.
129-
130- .. note::
131-
132- This class is used in type annotation and as a marker of array container
133- attributes for :func:`~arraycontext.dataclass_array_container`.
134- As a protocol, it is not intended as a superclass.
135- """
136-
137- # Array containers do not need to have any particular features, so this
138- # protocol is deliberately empty.
139-
140- # This *is* used as a type annotation in dataclasses that are processed
156+ class _UserDefinedArrayContainer (Protocol ):
157+ # This is used as a type annotation in dataclasses that are processed
141158 # by dataclass_array_container, where it's used to recognize attributes
142159 # that are container-typed.
143160
161+ # This method prevents ArrayContainer from matching any object, while
162+ # matching numpy object arrays and many array containers.
163+ __array_ufunc__ : ClassVar [None ]
144164
145- class ArithArrayContainer (ArrayContainer , Protocol ):
146- """
147- A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
148- """
149165
166+ ArrayContainer : TypeAlias = NDArray [Any ] | _UserDefinedArrayContainer
167+
168+
169+ class _UserDefinedArithArrayContainer (_UserDefinedArrayContainer , Protocol ):
150170 # This is loose and permissive, assuming that any array can be added
151171 # to any container. The alternative would be to plaster type-ignores
152172 # on all those uses. Achieving typing precision on what broadcasting is
@@ -167,6 +187,9 @@ def __pow__(self, other: ArrayOrScalar | Self) -> Self: ...
167187 def __rpow__ (self , other : ArrayOrScalar | Self ) -> Self : ...
168188
169189
190+ ArithArrayContainer : TypeAlias = NDArray [Any ] | _UserDefinedArithArrayContainer
191+
192+
170193ArrayContainerT = TypeVar ("ArrayContainerT" , bound = ArrayContainer )
171194
172195
@@ -175,7 +198,8 @@ class NotAnArrayContainerError(TypeError):
175198
176199
177200SerializationKey : TypeAlias = Hashable
178- SerializedContainer : TypeAlias = Sequence [tuple [SerializationKey , "ArrayOrContainer" ]]
201+ SerializedContainer : TypeAlias = Sequence [
202+ tuple [SerializationKey , ArrayOrContainerOrScalar ]]
179203
180204
181205@singledispatch
@@ -221,7 +245,7 @@ def deserialize_container(
221245 f"'{ type (template ).__name__ } ' cannot be deserialized as a container" )
222246
223247
224- def is_array_container_type (cls : type ) -> bool :
248+ def is_array_container_type (cls : type | GenericAlias | UnionType ) -> bool :
225249 """
226250 :returns: *True* if the type *cls* has a registered implementation of
227251 :func:`serialize_container`, or if it is an :class:`ArrayContainer`.
@@ -233,15 +257,22 @@ def is_array_container_type(cls: type) -> bool:
233257 function will say that :class:`numpy.ndarray` is an array container
234258 type, only object arrays *actually are* array containers.
235259 """
236- assert isinstance (cls , type ), f"must pass a { type !r} , not a '{ cls !r} '"
260+ if cls is ArrayContainer :
261+ return True
262+
263+ while isinstance (cls , GenericAlias ):
264+ cls = get_origin (cls )
265+
266+ assert isinstance (cls , type ), (
267+ f"must pass a { type !r} , not a '{ cls !r} '" )
237268
238269 return (
239- cls is ArrayContainer
270+ cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
240271 or (serialize_container .dispatch (cls )
241272 is not serialize_container .__wrapped__ )) # type:ignore[attr-defined]
242273
243274
244- def is_array_container (ary : object ) -> bool :
275+ def is_array_container (ary : object ) -> TypeIs [ ArrayContainer ] :
245276 """
246277 :returns: *True* if the instance *ary* has a registered implementation of
247278 :func:`serialize_container`.
@@ -317,7 +348,7 @@ def _deserialize_ndarray_container( # type: ignore[misc]
317348# {{{ get_container_context_recursively
318349
319350def get_container_context_recursively_opt (
320- ary : ArrayContainer ) -> ArrayContext | None :
351+ ary : ArrayOrContainerOrScalar ) -> ArrayContext | None :
321352 """Walks the :class:`ArrayContainer` hierarchy to find an
322353 :class:`ArrayContext` associated with it.
323354
@@ -351,7 +382,7 @@ def get_container_context_recursively_opt(
351382 return actx
352383
353384
354- def get_container_context_recursively (ary : ArrayContainer ) -> ArrayContext | None :
385+ def get_container_context_recursively (ary : ArrayContainer ) -> ArrayContext :
355386 """Walks the :class:`ArrayContainer` hierarchy to find an
356387 :class:`ArrayContext` associated with it.
357388
@@ -362,13 +393,7 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
362393 """
363394 actx = get_container_context_recursively_opt (ary )
364395 if actx is None :
365- # raise ValueError("no array context was found")
366- from warnings import warn
367- warn ("No array context was found. This will be an error starting in "
368- "July of 2022. If you would like the function to return "
369- "None if no array context was found, use "
370- "get_container_context_recursively_opt." ,
371- DeprecationWarning , stacklevel = 2 )
396+ raise ValueError ("no array context was found" )
372397
373398 return actx
374399
@@ -380,19 +405,20 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
380405# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
381406# (Though clearly there exists a dependency via loopy.)
382407
383- def _serialize_multivec_as_container (mv : MultiVector ) -> SerializedContainer :
408+ def _serialize_multivec_as_container (
409+ mv : MultiVector [ArrayOrArithContainer ]
410+ ) -> SerializedContainer :
384411 return list (mv .data .items ())
385412
386413
387- # FIXME: Ignored due to https://github.com/python/mypy/issues/13040
388- def _deserialize_multivec_as_container ( # type: ignore[misc]
389- template : MultiVector ,
390- serialized : SerializedContainer ) -> MultiVector :
414+ def _deserialize_multivec_as_container (
415+ template : MultiVector [CoeffT ],
416+ serialized : SerializedContainer ) -> MultiVector [CoeffT ]:
391417 from pymbolic .geometric_algebra import MultiVector
392418 return MultiVector (dict (serialized ), space = template .space )
393419
394420
395- def _get_container_context_opt_from_multivec (mv : MultiVector ) -> None :
421+ def _get_container_context_opt_from_multivec (mv : MultiVector [ CoeffT ] ) -> None :
396422 return None
397423
398424
0 commit comments