Skip to content

Commit 0a31907

Browse files
committed
Enable TC ruff lint
1 parent fca3855 commit 0a31907

File tree

21 files changed

+151
-68
lines changed

21 files changed

+151
-68
lines changed

arraycontext/container/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,6 @@
4040
4141
:canonical: arraycontext.ArrayContainerT
4242
43-
.. class:: ArrayOrContainerT
44-
45-
:canonical: arraycontext.ArrayOrContainerT
46-
4743
.. class:: SerializationKey
4844
4945
:canonical: arraycontext.SerializationKey
@@ -90,13 +86,12 @@
9086
import numpy as np
9187
from typing_extensions import Self
9288

93-
from arraycontext.context import ArrayContext, ArrayOrScalar
94-
9589

9690
if TYPE_CHECKING:
9791
from pymbolic.geometric_algebra import MultiVector
9892

9993
from arraycontext import ArrayOrContainer
94+
from arraycontext.context import ArrayContext, ArrayOrScalar
10095

10196

10297
# {{{ ArrayContainer

arraycontext/container/arithmetic.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@
3737

3838
import enum
3939
import operator
40-
from collections.abc import Callable
4140
from dataclasses import dataclass, field
4241
from functools import partialmethod
4342
from numbers import Number
44-
from typing import Any, TypeVar
43+
from typing import TYPE_CHECKING, Any, TypeVar
4544
from warnings import warn
4645

4746
import numpy as np
@@ -51,7 +50,12 @@
5150
deserialize_container,
5251
serialize_container,
5352
)
54-
from arraycontext.context import ArrayContext, ArrayOrContainer
53+
54+
55+
if TYPE_CHECKING:
56+
from collections.abc import Callable
57+
58+
from arraycontext.context import ArrayContext, ArrayOrContainer
5559

5660

5761
# {{{ with_container_arithmetic

arraycontext/container/dataclass.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@
3131
THE SOFTWARE.
3232
"""
3333

34-
from collections.abc import Mapping, Sequence
3534
from dataclasses import fields, is_dataclass
36-
from typing import NamedTuple, Union, get_args, get_origin
35+
from typing import TYPE_CHECKING, NamedTuple, Union, get_args, get_origin
3736

3837
from arraycontext.container import is_array_container_type
3938

4039

40+
if TYPE_CHECKING:
41+
from collections.abc import Mapping, Sequence
42+
43+
4144
# {{{ dataclass containers
4245

4346
class _Field(NamedTuple):

arraycontext/container/traversal.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,8 @@
7070
THE SOFTWARE.
7171
"""
7272

73-
from collections.abc import Callable, Iterable
7473
from functools import partial, singledispatch, update_wrapper
75-
from typing import Any, cast
74+
from typing import TYPE_CHECKING, Any, cast
7675
from warnings import warn
7776

7877
import numpy as np
@@ -87,14 +86,19 @@
8786
get_container_context_recursively_opt,
8887
serialize_container,
8988
)
90-
from arraycontext.context import (
91-
Array,
92-
ArrayContext,
93-
ArrayOrContainer,
94-
ArrayOrContainerOrScalar,
95-
ArrayOrContainerT,
96-
ScalarLike,
97-
)
89+
90+
91+
if TYPE_CHECKING:
92+
from collections.abc import Callable, Iterable
93+
94+
from arraycontext.context import (
95+
Array,
96+
ArrayContext,
97+
ArrayOrContainer,
98+
ArrayOrContainerOrScalar,
99+
ArrayOrContainerT,
100+
ScalarLike,
101+
)
98102

99103

100104
# {{{ array container traversal helpers
@@ -414,7 +418,7 @@ def rec(keys: tuple[SerializationKey, ...],
414418
try:
415419
iterable = serialize_container(ary_)
416420
except NotAnArrayContainerError:
417-
return cast(ArrayOrContainer, f(keys, cast(Array, ary_)))
421+
return cast("ArrayOrContainer", f(keys, cast("Array", ary_)))
418422
else:
419423
return deserialize_container(ary_, [
420424
(key, rec((*keys, key), subary)) for key, subary in iterable
@@ -699,7 +703,7 @@ def _flatten(subary: ArrayOrContainer) -> list[Array]:
699703
try:
700704
iterable = serialize_container(subary)
701705
except NotAnArrayContainerError:
702-
subary_c = cast(Array, subary)
706+
subary_c = cast("Array", subary)
703707

704708
if common_dtype is None:
705709
common_dtype = subary_c.dtype
@@ -786,7 +790,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
786790
try:
787791
iterable = serialize_container(template_subary)
788792
except NotAnArrayContainerError:
789-
template_subary_c = cast(Array, template_subary)
793+
template_subary_c = cast("Array", template_subary)
790794

791795
# {{{ validate subary
792796

@@ -877,7 +881,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
877881
raise ValueError("'template' and 'ary' sizes do not match: "
878882
"'ary' is too large")
879883

880-
return cast(ArrayOrContainerT, result)
884+
return cast("ArrayOrContainerT", result)
881885

882886

883887
def flat_size_and_dtype(
@@ -895,7 +899,7 @@ def _flat_size(subary: ArrayOrContainer) -> Array | Integer:
895899
try:
896900
iterable = serialize_container(subary)
897901
except NotAnArrayContainerError:
898-
subary_c = cast(Array, subary)
902+
subary_c = cast("Array", subary)
899903

900904
if common_dtype is None:
901905
common_dtype = subary_c.dtype

arraycontext/context.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,11 @@
8585
Types and Type Variables for Arrays and Containers
8686
--------------------------------------------------
8787
88+
.. autodata:: ScalarLike
89+
:noindex:
90+
91+
A type alias of :data:`pymbolic.Scalar`.
92+
8893
.. autoclass:: Array
8994
9095
.. autodata:: ArrayT
@@ -176,11 +181,11 @@
176181

177182
from pymbolic.typing import Integer, Scalar as _Scalar
178183
from pytools import memoize_method
179-
from pytools.tag import ToTagSetConvertible
180184

181185

182186
if TYPE_CHECKING:
183187
import loopy
188+
from pytools.tag import ToTagSetConvertible
184189

185190
from arraycontext.container import ArithArrayContainer, ArrayContainer
186191

@@ -254,7 +259,7 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Array: ...
254259
#
255260
# For now, they're purposefully not in the main arraycontext.* name space.
256261
ArrayT = TypeVar("ArrayT", bound=Array)
257-
ArrayOrScalar: TypeAlias = "Array | ScalarLike"
262+
ArrayOrScalar: TypeAlias = Array | ScalarLike
258263
ArrayOrContainer: TypeAlias = "Array | ArrayContainer"
259264
ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer"
260265
ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer)

arraycontext/impl/jax/__init__.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,21 @@
2929
THE SOFTWARE.
3030
"""
3131

32-
from collections.abc import Callable
3332

34-
import numpy as np
33+
from typing import TYPE_CHECKING
3534

36-
from pytools.tag import ToTagSetConvertible
35+
import numpy as np
3736

3837
from arraycontext.container.traversal import rec_map_array_container, with_array_context
3938
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
4039

4140

41+
if TYPE_CHECKING:
42+
from collections.abc import Callable
43+
44+
from pytools.tag import ToTagSetConvertible
45+
46+
4247
class EagerJAXArrayContext(ArrayContext):
4348
"""
4449
A :class:`ArrayContext` that uses

arraycontext/impl/jax/fake_numpy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
THE SOFTWARE.
2626
"""
2727
from functools import partial, reduce
28+
from typing import TYPE_CHECKING
2829

2930
import numpy as np
3031

@@ -39,10 +40,13 @@
3940
rec_map_reduce_array_container,
4041
rec_multimap_array_container,
4142
)
42-
from arraycontext.context import Array, ArrayOrContainer
4343
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace
4444

4545

46+
if TYPE_CHECKING:
47+
from arraycontext.context import Array, ArrayOrContainer
48+
49+
4650
class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
4751
# Everything is implemented in the base class for now.
4852
pass

arraycontext/impl/numpy/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,11 @@
3333
THE SOFTWARE.
3434
"""
3535

36-
from typing import Any, overload
36+
from typing import TYPE_CHECKING, Any, overload
3737

3838
import numpy as np
3939

4040
import loopy as lp
41-
from pytools.tag import ToTagSetConvertible
4241

4342
from arraycontext.container.traversal import rec_map_array_container, with_array_context
4443
from arraycontext.context import (
@@ -52,6 +51,10 @@
5251
)
5352

5453

54+
if TYPE_CHECKING:
55+
from pytools.tag import ToTagSetConvertible
56+
57+
5558
class NumpyNonObjectArrayMetaclass(type):
5659
def __instancecheck__(cls, instance: Any) -> bool:
5760
return isinstance(instance, np.ndarray) and instance.dtype != object

arraycontext/impl/numpy/fake_numpy.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"""
2727

2828
from functools import partial, reduce
29-
from typing import cast
29+
from typing import TYPE_CHECKING, cast
3030

3131
import numpy as np
3232

@@ -37,13 +37,16 @@
3737
rec_multimap_array_container,
3838
rec_multimap_reduce_array_container,
3939
)
40-
from arraycontext.context import Array, ArrayOrContainer
4140
from arraycontext.fake_numpy import (
4241
BaseFakeNumpyLinalgNamespace,
4342
BaseFakeNumpyNamespace,
4443
)
4544

4645

46+
if TYPE_CHECKING:
47+
from arraycontext.context import Array, ArrayOrContainer
48+
49+
4750
class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
4851
# Everything is implemented in the base class for now.
4952
pass
@@ -150,7 +153,7 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
150153
return false_ary
151154
return np.logical_and.reduce(
152155
[(true_ary if kx_i == ky_i else false_ary)
153-
and cast(np.ndarray, self.array_equal(x_i, y_i))
156+
and cast("np.ndarray", self.array_equal(x_i, y_i))
154157
for (kx_i, x_i), (ky_i, y_i)
155158
in zip(serialized_x, serialized_y, strict=True)],
156159
initial=true_ary)

arraycontext/impl/pyopencl/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,11 @@
3131
THE SOFTWARE.
3232
"""
3333

34-
from collections.abc import Callable
3534
from typing import TYPE_CHECKING, Literal
3635
from warnings import warn
3736

3837
import numpy as np
3938

40-
from pytools.tag import ToTagSetConvertible
41-
4239
from arraycontext.container.traversal import rec_map_array_container, with_array_context
4340
from arraycontext.context import (
4441
Array,
@@ -50,9 +47,12 @@
5047

5148

5249
if TYPE_CHECKING:
50+
from collections.abc import Callable
51+
5352
import loopy as lp
5453
import pyopencl as cl
5554
import pyopencl.array as cl_array
55+
from pytools.tag import ToTagSetConvertible
5656

5757

5858
# {{{ PyOpenCLArrayContext

0 commit comments

Comments
 (0)