Skip to content

Commit 9c56443

Browse files
Merge branch 'main' into cupyactx
2 parents 5871ae7 + d8e8683 commit 9c56443

29 files changed

+565
-304
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ on:
77
schedule:
88
- cron: '17 3 * * 0'
99

10+
concurrency:
11+
group: ${{ github.head_ref || github.ref_name }}
12+
cancel-in-progress: true
13+
1014
jobs:
1115
typos:
1216
name: Typos

MANIFEST.in

Lines changed: 0 additions & 10 deletions
This file was deleted.

arraycontext/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
An array context is an abstraction that helps you dispatch between multiple
33
implementations of :mod:`numpy`-like :math:`n`-dimensional arrays.
44
"""
5+
from __future__ import annotations
56

67

78
__copyright__ = """
@@ -29,6 +30,7 @@
2930
"""
3031

3132
from .container import (
33+
ArithArrayContainer,
3234
ArrayContainer,
3335
ArrayContainerT,
3436
NotAnArrayContainerError,
@@ -72,6 +74,10 @@
7274
from .context import (
7375
Array,
7476
ArrayContext,
77+
ArrayOrArithContainer,
78+
ArrayOrArithContainerOrScalar,
79+
ArrayOrArithContainerOrScalarT,
80+
ArrayOrArithContainerT,
7581
ArrayOrContainer,
7682
ArrayOrContainerOrScalar,
7783
ArrayOrContainerOrScalarT,
@@ -96,10 +102,15 @@
96102

97103

98104
__all__ = (
105+
"ArithArrayContainer",
99106
"Array",
100107
"ArrayContainer",
101108
"ArrayContainerT",
102109
"ArrayContext",
110+
"ArrayOrArithContainer",
111+
"ArrayOrArithContainerOrScalar",
112+
"ArrayOrArithContainerOrScalarT",
113+
"ArrayOrArithContainerT",
103114
"ArrayOrContainer",
104115
"ArrayOrContainerOrScalar",
105116
"ArrayOrContainerOrScalarT",

arraycontext/container/__init__.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
.. currentmodule:: arraycontext
55
66
.. autoclass:: ArrayContainer
7+
.. autoclass:: ArithArrayContainer
78
.. class:: ArrayContainerT
89
910
A type variable with a lower bound of :class:`ArrayContainer`.
@@ -81,14 +82,15 @@
8182

8283
from collections.abc import Hashable, Sequence
8384
from functools import singledispatch
84-
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar
85+
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar
8586

8687
# For use in singledispatch type annotations, because sphinx can't figure out
8788
# what 'np' is.
8889
import numpy
8990
import numpy as np
91+
from typing_extensions import Self
9092

91-
from arraycontext.context import ArrayContext
93+
from arraycontext.context import ArrayContext, ArrayOrScalar
9294

9395

9496
if TYPE_CHECKING:
@@ -145,6 +147,29 @@ class ArrayContainer(Protocol):
145147
# that are container-typed.
146148

147149

150+
class ArithArrayContainer(ArrayContainer, Protocol):
151+
"""
152+
A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
153+
"""
154+
155+
# This is loose and permissive, assuming that any array can be added
156+
# to any container. The alternative would be to plaster type-ignores
157+
# on all those uses. Achieving typing precision on what broadcasting is
158+
# allowable seems like a huge endeavor and is likely not feasible without
159+
# a mypy plugin. Maybe some day? -AK, November 2024
160+
161+
def __neg__(self) -> Self: ...
162+
def __abs__(self) -> Self: ...
163+
def __add__(self, other: ArrayOrScalar | Self) -> Self: ...
164+
def __radd__(self, other: ArrayOrScalar | Self) -> Self: ...
165+
def __sub__(self, other: ArrayOrScalar | Self) -> Self: ...
166+
def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ...
167+
def __mul__(self, other: ArrayOrScalar | Self) -> Self: ...
168+
def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ...
169+
def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ...
170+
def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ...
171+
172+
148173
ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
149174

150175

@@ -219,7 +244,7 @@ def is_array_container_type(cls: type) -> bool:
219244
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
220245

221246

222-
def is_array_container(ary: Any) -> bool:
247+
def is_array_container(ary: object) -> bool:
223248
"""
224249
:returns: *True* if the instance *ary* has a registered implementation of
225250
:func:`serialize_container`.

arraycontext/container/dataclass.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
.. currentmodule:: arraycontext
55
.. autofunction:: dataclass_array_container
66
"""
7+
from __future__ import annotations
78

89

910
__copyright__ = """
@@ -30,6 +31,7 @@
3031
THE SOFTWARE.
3132
"""
3233

34+
from collections.abc import Mapping, Sequence
3335
from dataclasses import Field, fields, is_dataclass
3436
from typing import Union, get_args, get_origin
3537

@@ -57,11 +59,21 @@ def dataclass_array_container(cls: type) -> type:
5759
* a :class:`typing.Union` of array containers is considered an array container.
5860
* other type annotations, e.g. :class:`typing.Optional`, are not considered
5961
array containers, even if they wrap one.
62+
63+
.. note::
64+
65+
When type annotations are strings (e.g. because of
66+
``from __future__ import annotations``),
67+
this function relies on :func:`inspect.get_annotations`
68+
(with ``eval_str=True``) to obtain type annotations. This
69+
means that *cls* must live in a module that is importable.
6070
"""
6171

72+
from types import GenericAlias, UnionType
73+
6274
assert is_dataclass(cls)
6375

64-
def is_array_field(f: Field) -> bool:
76+
def is_array_field(f: Field, field_type: type) -> bool:
6577
# NOTE: unions of array containers are treated separately to handle
6678
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
6779
# they can work seamlessly with arithmetic and traversal.
@@ -74,16 +86,17 @@ def is_array_field(f: Field) -> bool:
7486
#
7587
# This is not set in stone, but mostly driven by current usage!
7688

77-
origin = get_origin(f.type)
78-
if origin is Union:
79-
if all(is_array_type(arg) for arg in get_args(f.type)):
89+
origin = get_origin(field_type)
90+
# NOTE: `UnionType` is returned when using `Type1 | Type2`
91+
if origin in (Union, UnionType):
92+
if all(is_array_type(arg) for arg in get_args(field_type)):
8093
return True
8194
else:
8295
raise TypeError(
8396
f"Field '{f.name}' union contains non-array container "
8497
"arguments. All arguments must be array containers.")
8598

86-
if isinstance(f.type, str):
99+
if isinstance(field_type, str):
87100
raise TypeError(
88101
f"String annotation on field '{f.name}' not supported. "
89102
"(this may be due to 'from __future__ import annotations')")
@@ -94,39 +107,56 @@ def is_array_field(f: Field) -> bool:
94107
f"Field with 'init=False' not allowed: '{f.name}'")
95108

96109
# NOTE:
110+
# * `GenericAlias` catches typed `list`, `tuple`, etc.
97111
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
98112
# * `_SpecialForm` catches `Any`, `Literal`, etc.
99113
from typing import ( # type: ignore[attr-defined]
100114
_BaseGenericAlias,
101115
_SpecialForm,
102116
)
103-
if isinstance(f.type, _BaseGenericAlias | _SpecialForm):
117+
if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm):
104118
# NOTE: anything except a Union is not allowed
105119
raise TypeError(
106120
f"Typing annotation not supported on field '{f.name}': "
107-
f"'{f.type!r}'")
121+
f"'{field_type!r}'")
108122

109-
if not isinstance(f.type, type):
123+
if not isinstance(field_type, type):
110124
raise TypeError(
111125
f"Field '{f.name}' not an instance of 'type': "
112-
f"'{f.type!r}'")
126+
f"'{field_type!r}'")
127+
128+
return is_array_type(field_type)
129+
130+
from inspect import get_annotations
113131

114-
return is_array_type(f.type)
132+
array_fields: list[Field] = []
133+
non_array_fields: list[Field] = []
134+
cls_ann: Mapping[str, type] | None = None
135+
for field in fields(cls):
136+
field_type_or_str = field.type
137+
if isinstance(field_type_or_str, str):
138+
if cls_ann is None:
139+
cls_ann = get_annotations(cls, eval_str=True)
140+
field_type = cls_ann[field.name]
141+
else:
142+
field_type = field_type_or_str
115143

116-
from pytools import partition
117-
array_fields, non_array_fields = partition(is_array_field, fields(cls))
144+
if is_array_field(field, field_type):
145+
array_fields.append(field)
146+
else:
147+
non_array_fields.append(field)
118148

119149
if not array_fields:
120150
raise ValueError(f"'{cls}' must have fields with array container type "
121151
"in order to use the 'dataclass_array_container' decorator")
122152

123-
return inject_dataclass_serialization(cls, array_fields, non_array_fields)
153+
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
124154

125155

126-
def inject_dataclass_serialization(
156+
def _inject_dataclass_serialization(
127157
cls: type,
128-
array_fields: tuple[Field, ...],
129-
non_array_fields: tuple[Field, ...]) -> type:
158+
array_fields: Sequence[Field],
159+
non_array_fields: Sequence[Field]) -> type:
130160
"""Implements :func:`~arraycontext.serialize_container` and
131161
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
132162

0 commit comments

Comments
 (0)