Skip to content

Commit 5871ae7

Browse files
Merge branch 'main' into cupyactx
2 parents d61f0cf + dee0ca4 commit 5871ae7

24 files changed

+382
-275
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,6 @@ jobs:
5252
curl -L -O https://tiker.net/ci-support-v0
5353
. ./ci-support-v0
5454
55-
# NOTE: jax>=0.4.31 requires python 3.10 and uses pattern matching
56-
# which conflicts with our mypy.python_version = '3.8' setting
57-
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
58-
sed -i "s/jax/jax<0.4.31/" "$CONDA_ENVIRONMENT"
59-
echo "- cupy" >> "$CONDA_ENVIRONMENT"
60-
6155
build_py_project_in_conda_env
6256
python -m pip install mypy pytest
6357
./run-mypy.sh

.gitlab-ci.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ Pylint:
107107

108108
Mypy:
109109
script: |
110-
# NOTE: jax>=0.4.31 requires python 3.10 and uses pattern matching
111-
# which conflicts with our mypy.python_version = '3.8' setting
112-
EXTRA_INSTALL="mypy pytest jax[cpu]<0.4.31"
110+
EXTRA_INSTALL="mypy pytest"
113111
114112
curl -L -O https://tiker.net/ci-support-v0
115113
. ./ci-support-v0

arraycontext/container/__init__.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -79,24 +79,14 @@
7979
THE SOFTWARE.
8080
"""
8181

82+
from collections.abc import Hashable, Sequence
8283
from functools import singledispatch
83-
from typing import (
84-
TYPE_CHECKING,
85-
Any,
86-
Hashable,
87-
Iterable,
88-
Optional,
89-
Protocol,
90-
Sequence,
91-
Tuple,
92-
TypeVar,
93-
)
84+
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar
9485

9586
# For use in singledispatch type annotations, because sphinx can't figure out
9687
# what 'np' is.
9788
import numpy
9889
import numpy as np
99-
from typing_extensions import TypeAlias
10090

10191
from arraycontext.context import ArrayContext
10292

@@ -154,8 +144,6 @@ class ArrayContainer(Protocol):
154144
# by dataclass_array_container, where it's used to recognize attributes
155145
# that are container-typed.
156146

157-
pass
158-
159147

160148
ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
161149

@@ -165,7 +153,7 @@ class NotAnArrayContainerError(TypeError):
165153

166154

167155
SerializationKey: TypeAlias = Hashable
168-
SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]]
156+
SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]]
169157

170158

171159
@singledispatch
@@ -252,7 +240,7 @@ def is_array_container(ary: Any) -> bool:
252240

253241

254242
@singledispatch
255-
def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]:
243+
def get_container_context_opt(ary: ArrayContainer) -> ArrayContext | None:
256244
"""Retrieves the :class:`ArrayContext` from the container, if any.
257245
258246
This function is not recursive, so it will only search at the root level
@@ -306,7 +294,7 @@ def _deserialize_ndarray_container( # type: ignore[misc]
306294
# {{{ get_container_context_recursively
307295

308296
def get_container_context_recursively_opt(
309-
ary: ArrayContainer) -> Optional[ArrayContext]:
297+
ary: ArrayContainer) -> ArrayContext | None:
310298
"""Walks the :class:`ArrayContainer` hierarchy to find an
311299
:class:`ArrayContext` associated with it.
312300
@@ -340,7 +328,7 @@ def get_container_context_recursively_opt(
340328
return actx
341329

342330

343-
def get_container_context_recursively(ary: ArrayContainer) -> Optional[ArrayContext]:
331+
def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None:
344332
"""Walks the :class:`ArrayContainer` hierarchy to find an
345333
:class:`ArrayContext` associated with it.
346334
@@ -369,14 +357,16 @@ def get_container_context_recursively(ary: ArrayContainer) -> Optional[ArrayCont
369357
# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
370358
# (Though clearly there exists a dependency via loopy.)
371359

372-
def _serialize_multivec_as_container(mv: MultiVector) -> Iterable[Tuple[Any, Any]]:
360+
def _serialize_multivec_as_container(mv: MultiVector) -> SerializedContainer:
373361
return list(mv.data.items())
374362

375363

376-
def _deserialize_multivec_as_container(template: MultiVector,
377-
iterable: Iterable[Tuple[Any, Any]]) -> MultiVector:
364+
# FIXME: Ignored due to https://github.com/python/mypy/issues/13040
365+
def _deserialize_multivec_as_container( # type: ignore[misc]
366+
template: MultiVector,
367+
serialized: SerializedContainer) -> MultiVector:
378368
from pymbolic.geometric_algebra import MultiVector
379-
return MultiVector(dict(iterable), space=template.space)
369+
return MultiVector(dict(serialized), space=template.space)
380370

381371

382372
def _get_container_context_opt_from_multivec(mv: MultiVector) -> None:

arraycontext/container/arithmetic.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
"""
3535

3636
import enum
37-
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
37+
from collections.abc import Callable
38+
from typing import Any, TypeVar
3839
from warnings import warn
3940

4041
import numpy as np
@@ -90,7 +91,7 @@ class _OpClass(enum.Enum):
9091
]
9192

9293

93-
def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
94+
def _format_unary_op_str(op_str: str, arg1: tuple[str, ...] | str) -> str:
9495
if isinstance(arg1, tuple):
9596
arg1_entry, arg1_container = arg1
9697
return (f"{op_str.format(arg1_entry)} "
@@ -100,20 +101,14 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
100101

101102

102103
def _format_binary_op_str(op_str: str,
103-
arg1: Union[Tuple[str, str], str],
104-
arg2: Union[Tuple[str, str], str]) -> str:
104+
arg1: tuple[str, str] | str,
105+
arg2: tuple[str, str] | str) -> str:
105106
if isinstance(arg1, tuple) and isinstance(arg2, tuple):
106-
import sys
107-
if sys.version_info >= (3, 10):
108-
strict_arg = ", strict=__debug__"
109-
else:
110-
strict_arg = ""
111-
112107
arg1_entry, arg1_container = arg1
113108
arg2_entry, arg2_container = arg2
114109
return (f"{op_str.format(arg1_entry, arg2_entry)} "
115110
f"for {arg1_entry}, {arg2_entry} "
116-
f"in zip({arg1_container}, {arg2_container}{strict_arg})")
111+
f"in zip({arg1_container}, {arg2_container}, strict=__debug__)")
117112

118113
elif isinstance(arg1, tuple):
119114
arg1_entry, arg1_container = arg1
@@ -160,23 +155,23 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
160155

161156
def with_container_arithmetic(
162157
*,
163-
number_bcasts_across: Optional[bool] = None,
164-
bcasts_across_obj_array: Optional[bool] = None,
165-
container_types_bcast_across: Optional[Tuple[type, ...]] = None,
158+
number_bcasts_across: bool | None = None,
159+
bcasts_across_obj_array: bool | None = None,
160+
container_types_bcast_across: tuple[type, ...] | None = None,
166161
arithmetic: bool = True,
167162
matmul: bool = False,
168163
bitwise: bool = False,
169164
shift: bool = False,
170-
_cls_has_array_context_attr: Optional[bool] = None,
171-
eq_comparison: Optional[bool] = None,
172-
rel_comparison: Optional[bool] = None,
165+
_cls_has_array_context_attr: bool | None = None,
166+
eq_comparison: bool | None = None,
167+
rel_comparison: bool | None = None,
173168

174169
# deprecated:
175-
bcast_number: Optional[bool] = None,
176-
bcast_obj_array: Optional[bool] = None,
170+
bcast_number: bool | None = None,
171+
bcast_obj_array: bool | None = None,
177172
bcast_numpy_array: bool = False,
178-
_bcast_actx_array_type: Optional[bool] = None,
179-
bcast_container_types: Optional[Tuple[type, ...]] = None,
173+
_bcast_actx_array_type: bool | None = None,
174+
bcast_container_types: tuple[type, ...] | None = None,
180175
) -> Callable[[type], type]:
181176
"""A class decorator that implements built-in operators for array containers
182177
by propagating the operations to the elements of the container.
@@ -482,7 +477,7 @@ def same_key(k1: T, k2: T) -> T:
482477
assert k1 == k2
483478
return k1
484479

485-
def tup_str(t: Tuple[str, ...]) -> str:
480+
def tup_str(t: tuple[str, ...]) -> str:
486481
if not t:
487482
return "()"
488483
else:
@@ -544,7 +539,8 @@ def {fname}(arg1):
544539
_format_binary_op_str(op_str, expr_arg1, expr_arg2)
545540
for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip(
546541
cls._serialize_init_arrays_code("arg1").items(),
547-
cls._serialize_init_arrays_code("arg2").items())
542+
cls._serialize_init_arrays_code("arg2").items(),
543+
strict=True)
548544
})
549545
bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", {
550546
key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")

arraycontext/container/dataclass.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"""
3232

3333
from dataclasses import Field, fields, is_dataclass
34-
from typing import Tuple, Union, get_args, get_origin
34+
from typing import Union, get_args, get_origin
3535

3636
from arraycontext.container import is_array_container_type
3737

@@ -83,14 +83,15 @@ def is_array_field(f: Field) -> bool:
8383
f"Field '{f.name}' union contains non-array container "
8484
"arguments. All arguments must be array containers.")
8585

86+
if isinstance(f.type, str):
87+
raise TypeError(
88+
f"String annotation on field '{f.name}' not supported. "
89+
"(this may be due to 'from __future__ import annotations')")
90+
8691
if __debug__:
8792
if not f.init:
8893
raise ValueError(
89-
f"'init=False' field not allowed: '{f.name}'")
90-
91-
if isinstance(f.type, str):
92-
raise TypeError(
93-
f"string annotation on field '{f.name}' not supported")
94+
f"Field with 'init=False' not allowed: '{f.name}'")
9495

9596
# NOTE:
9697
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
@@ -99,15 +100,15 @@ def is_array_field(f: Field) -> bool:
99100
_BaseGenericAlias,
100101
_SpecialForm,
101102
)
102-
if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
103+
if isinstance(f.type, _BaseGenericAlias | _SpecialForm):
103104
# NOTE: anything except a Union is not allowed
104105
raise TypeError(
105-
f"typing annotation not supported on field '{f.name}': "
106+
f"Typing annotation not supported on field '{f.name}': "
106107
f"'{f.type!r}'")
107108

108109
if not isinstance(f.type, type):
109110
raise TypeError(
110-
f"field '{f.name}' not an instance of 'type': "
111+
f"Field '{f.name}' not an instance of 'type': "
111112
f"'{f.type!r}'")
112113

113114
return is_array_type(f.type)
@@ -124,8 +125,8 @@ def is_array_field(f: Field) -> bool:
124125

125126
def inject_dataclass_serialization(
126127
cls: type,
127-
array_fields: Tuple[Field, ...],
128-
non_array_fields: Tuple[Field, ...]) -> type:
128+
array_fields: tuple[Field, ...],
129+
non_array_fields: tuple[Field, ...]) -> type:
129130
"""Implements :func:`~arraycontext.serialize_container` and
130131
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
131132

0 commit comments

Comments
 (0)