Skip to content

Commit 600c659

Browse files
alexfiklinducer
authored andcommitted
enable and apply isort
1 parent 812b5ea commit 600c659

23 files changed

+205
-206
lines changed

arraycontext/__init__.py

Lines changed: 25 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -29,61 +29,38 @@
2929
"""
3030

3131
import sys
32-
from .context import (
33-
ArrayContext,
34-
35-
Scalar, ScalarLike,
36-
Array, ArrayT,
37-
ArrayOrContainer, ArrayOrContainerT,
38-
ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT,
39-
40-
tag_axes)
41-
42-
from .transform_metadata import (CommonSubexpressionTag,
43-
ElementwiseMapKernelTag)
44-
45-
# deprecated, remove in 2022.
46-
from .metadata import _FirstAxisIsElementsTag
4732

4833
from .container import (
49-
ArrayContainer, ArrayContainerT,
50-
NotAnArrayContainerError,
51-
is_array_container, is_array_container_type,
52-
get_container_context_opt,
53-
get_container_context_recursively, get_container_context_recursively_opt,
54-
serialize_container, deserialize_container,
55-
register_multivector_as_array_container)
34+
ArrayContainer, ArrayContainerT, NotAnArrayContainerError, deserialize_container,
35+
get_container_context_opt, get_container_context_recursively,
36+
get_container_context_recursively_opt, is_array_container,
37+
is_array_container_type, register_multivector_as_array_container,
38+
serialize_container)
5639
from .container.arithmetic import with_container_arithmetic
5740
from .container.dataclass import dataclass_array_container
58-
5941
from .container.traversal import (
60-
map_array_container,
61-
multimap_array_container,
62-
rec_map_array_container,
63-
rec_multimap_array_container,
64-
mapped_over_array_containers,
65-
multimapped_over_array_containers,
66-
map_reduce_array_container,
67-
multimap_reduce_array_container,
68-
rec_map_reduce_array_container,
69-
rec_multimap_reduce_array_container,
70-
thaw, freeze,
71-
flatten, unflatten, flat_size_and_dtype,
72-
from_numpy, to_numpy,
73-
outer, with_array_context)
74-
75-
from .impl.pyopencl import PyOpenCLArrayContext
76-
from .impl.pytato import (PytatoPyOpenCLArrayContext,
77-
PytatoJAXArrayContext)
42+
flat_size_and_dtype, flatten, freeze, from_numpy, map_array_container,
43+
map_reduce_array_container, mapped_over_array_containers,
44+
multimap_array_container, multimap_reduce_array_container,
45+
multimapped_over_array_containers, outer, rec_map_array_container,
46+
rec_map_reduce_array_container, rec_multimap_array_container,
47+
rec_multimap_reduce_array_container, thaw, to_numpy, unflatten,
48+
with_array_context)
49+
from .context import (
50+
Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar,
51+
ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike,
52+
tag_axes)
7853
from .impl.jax import EagerJAXArrayContext
79-
80-
from .pytest import (
81-
PytestArrayContextFactory,
82-
PytestPyOpenCLArrayContextFactory,
83-
pytest_generate_tests_for_array_contexts,
84-
pytest_generate_tests_for_pyopencl_array_context)
85-
54+
from .impl.pyopencl import PyOpenCLArrayContext
55+
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
8656
from .loopy import make_loopy_program
57+
# deprecated, remove in 2022.
58+
from .metadata import _FirstAxisIsElementsTag
59+
from .pytest import (
60+
PytestArrayContextFactory, PytestPyOpenCLArrayContextFactory,
61+
pytest_generate_tests_for_array_contexts,
62+
pytest_generate_tests_for_pyopencl_array_context)
63+
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
8764

8865

8966
__all__ = (

arraycontext/container/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,19 @@
6969
"""
7070

7171
from functools import singledispatch
72-
from arraycontext.context import ArrayContext
73-
from typing import Any, Iterable, Tuple, Optional, TypeVar, Protocol, TYPE_CHECKING
74-
import numpy as np
72+
from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar
7573

7674
# For use in singledispatch type annotations, because sphinx can't figure out
7775
# what 'np' is.
7876
import numpy
77+
import numpy as np
78+
79+
from arraycontext.context import ArrayContext
7980

8081

8182
if TYPE_CHECKING:
8283
from pymbolic.geometric_algebra import MultiVector
84+
8385
from arraycontext import ArrayOrContainer
8486

8587

arraycontext/container/arithmetic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import enum
99

10+
1011
__copyright__ = """
1112
Copyright (C) 2020-1 University of Illinois Board of Trustees
1213
"""
@@ -31,8 +32,8 @@
3132
THE SOFTWARE.
3233
"""
3334

35+
from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union
3436
from warnings import warn
35-
from typing import Any, Callable, Optional, Tuple, TypeVar, Union, Type
3637

3738
import numpy as np
3839

arraycontext/container/dataclass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
THE SOFTWARE.
3131
"""
3232

33+
from dataclasses import Field, fields, is_dataclass
3334
from typing import Tuple, Union, get_args, get_origin
3435

35-
from dataclasses import Field, is_dataclass, fields
3636
from arraycontext.container import is_array_container_type
3737

3838

@@ -95,7 +95,7 @@ def is_array_field(f: Field) -> bool:
9595
# NOTE:
9696
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
9797
# * `_SpecialForm` catches `Any`, `Literal`, etc.
98-
from typing import ( # type: ignore[attr-defined]
98+
from typing import ( # type: ignore[attr-defined]
9999
_BaseGenericAlias, _SpecialForm)
100100
if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
101101
# NOTE: anything except a Union is not allowed

arraycontext/container/traversal.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
from __future__ import annotations
4343

44+
4445
__copyright__ = """
4546
Copyright (C) 2020-1 University of Illinois Board of Trustees
4647
"""
@@ -65,22 +66,18 @@
6566
THE SOFTWARE.
6667
"""
6768

68-
from typing import Any, Callable, Iterable, List, Optional, Union, Tuple, cast
69-
from functools import update_wrapper, partial, singledispatch
69+
from functools import partial, singledispatch, update_wrapper
70+
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast
7071
from warnings import warn
7172

7273
import numpy as np
7374

74-
from arraycontext.context import (
75-
ArrayT, ArrayOrContainer, ArrayOrContainerT,
76-
ArrayOrContainerOrScalar, ScalarLike,
77-
ArrayContext, Array
78-
)
7975
from arraycontext.container import (
80-
NotAnArrayContainerError,
81-
ArrayContainer,
82-
serialize_container, deserialize_container,
83-
get_container_context_recursively_opt)
76+
ArrayContainer, NotAnArrayContainerError, deserialize_container,
77+
get_container_context_recursively_opt, serialize_container)
78+
from arraycontext.context import (
79+
Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar,
80+
ArrayOrContainerT, ArrayT, ScalarLike)
8481

8582

8683
# {{{ array container traversal helpers

arraycontext/context.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,18 @@
160160

161161
from abc import ABC, abstractmethod
162162
from typing import (
163-
Any, Callable, Dict, Optional, Tuple, Union, Mapping, Protocol, TypeVar,
164-
TYPE_CHECKING)
163+
TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, Protocol, Tuple, TypeVar,
164+
Union)
165165

166166
import numpy as np
167+
167168
from pytools import memoize_method
168169
from pytools.tag import ToTagSetConvertible
169170

171+
170172
if TYPE_CHECKING:
171173
import loopy
174+
172175
from arraycontext.container import ArrayContainer
173176

174177

@@ -426,8 +429,9 @@ def _get_einsum_prg(self,
426429
spec: str, arg_names: Tuple[str, ...],
427430
tagged: ToTagSetConvertible) -> "loopy.TranslationUnit":
428431
import loopy as lp
429-
from .loopy import _DEFAULT_LOOPY_OPTIONS
430432
from loopy.version import MOST_RECENT_LANGUAGE_VERSION
433+
434+
from .loopy import _DEFAULT_LOOPY_OPTIONS
431435
return lp.make_einsum(
432436
spec,
433437
arg_names,

arraycontext/fake_numpy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525

2626
import numpy as np
27+
2728
from arraycontext.container import NotAnArrayContainerError, serialize_container
2829
from arraycontext.container.traversal import rec_map_array_container
2930

@@ -105,8 +106,8 @@ def conjugate(self, x):
105106
# {{{ BaseFakeNumpyLinalgNamespace
106107

107108
def _reduce_norm(actx, arys, ord):
108-
from numbers import Number
109109
from functools import reduce
110+
from numbers import Number
110111

111112
if ord is None:
112113
ord = 2

arraycontext/impl/jax/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@
3232
import numpy as np
3333

3434
from pytools.tag import ToTagSetConvertible
35-
from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike
36-
from arraycontext.container.traversal import (with_array_context,
37-
rec_map_array_container)
35+
36+
from arraycontext.container.traversal import (
37+
rec_map_array_container, with_array_context)
38+
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
3839

3940

4041
class EagerJAXArrayContext(ArrayContext):

arraycontext/impl/jax/fake_numpy.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,15 @@
2323
"""
2424
from functools import partial, reduce
2525

26-
import numpy as np
2726
import jax.numpy as jnp
27+
import numpy as np
2828

29-
from arraycontext.fake_numpy import (
30-
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
31-
)
32-
from arraycontext.container.traversal import (
33-
rec_multimap_array_container, rec_map_array_container,
34-
rec_map_reduce_array_container,
35-
)
3629
from arraycontext.container import NotAnArrayContainerError, serialize_container
30+
from arraycontext.container.traversal import (
31+
rec_map_array_container, rec_map_reduce_array_container,
32+
rec_multimap_array_container)
33+
from arraycontext.fake_numpy import (
34+
BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace)
3735

3836

3937
class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):

arraycontext/impl/pyopencl/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,21 @@
2828
THE SOFTWARE.
2929
"""
3030

31+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
3132
from warnings import warn
32-
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
3333

3434
import numpy as np
3535

3636
from pytools.tag import ToTagSetConvertible
3737

38-
from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike
39-
from arraycontext.container.traversal import (rec_map_array_container,
40-
with_array_context)
38+
from arraycontext.container.traversal import (
39+
rec_map_array_container, with_array_context)
40+
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
4141

4242

4343
if TYPE_CHECKING:
44-
import pyopencl
4544
import loopy as lp
45+
import pyopencl
4646

4747

4848
# {{{ PyOpenCLArrayContext
@@ -287,6 +287,7 @@ def call_loopy(self, t_unit, **kwargs):
287287
wait_event_queue.pop(0).wait()
288288

289289
import arraycontext.impl.pyopencl.taggable_cl_array as tga
290+
290291
# FIXME: Inherit loopy tags for these arrays
291292
return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()}
292293

0 commit comments

Comments
 (0)