Skip to content

Commit 6bc8d56

Browse files
authored
Merge branch 'main' into flatten-to-numpy
2 parents f16ccab + 9c24abb commit 6bc8d56

File tree

8 files changed

+36
-39
lines changed

8 files changed

+36
-39
lines changed

arraycontext/container/__init__.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
2323
Serialization/deserialization
2424
-----------------------------
25-
.. autofunction:: is_array_container
25+
.. autofunction:: is_array_container_type
2626
.. autofunction:: serialize_container
2727
.. autofunction:: deserialize_container
2828
@@ -95,8 +95,6 @@ class ArrayContainer:
9595
9696
This allows enumeration of the component arrays in a container and the
9797
construction of modified containers from an iterable of those component arrays.
98-
:func:`is_array_container` will return *True* for types that have
99-
a container serialization function registered.
10098
10199
Packages may register their own types as array containers. They must not
102100
register other types (e.g. :class:`list`) as array containers.
@@ -160,6 +158,13 @@ def is_array_container_type(cls: type) -> bool:
160158
"""
161159
:returns: *True* if the type *cls* has a registered implementation of
162160
:func:`serialize_container`, or if it is an :class:`ArrayContainer`.
161+
162+
.. warning::
163+
164+
Not all instances of a type that this function labels an array container
165+
must automatically be array containers. For example, while this
166+
function will say that :class:`numpy.ndarray` is an array container
167+
type, only object arrays *actually are* array containers.
163168
"""
164169
return (
165170
cls is ArrayContainer
@@ -172,6 +177,13 @@ def is_array_container(ary: Any) -> bool:
172177
:returns: *True* if the instance *ary* has a registered implementation of
173178
:func:`serialize_container`.
174179
"""
180+
181+
from warnings import warn
182+
warn("is_array_container is deprecated and will be removed in 2022. "
183+
"If you must know precisely whether something is an array container, "
184+
"try serializing it and catch TypeError. For a cheaper option, see "
185+
"is_array_container_type.",
186+
DeprecationWarning, stacklevel=2)
175187
return (serialize_container.dispatch(ary.__class__)
176188
is not serialize_container.__wrapped__) # type:ignore[attr-defined]
177189

@@ -194,7 +206,7 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]:
194206
@serialize_container.register(np.ndarray)
195207
def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]:
196208
if ary.dtype.char != "O":
197-
raise ValueError(
209+
raise TypeError(
198210
f"cannot seriealize '{type(ary).__name__}' with dtype '{ary.dtype}'")
199211

200212
# special-cased for speed
@@ -236,7 +248,7 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
236248
any level, an assertion error is raised.
237249
"""
238250
actx = None
239-
if not is_array_container(ary):
251+
if not is_array_container_type(ary.__class__):
240252
return actx
241253

242254
# try getting the array context directly

arraycontext/container/dataclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def dataclass_array_container(cls: type) -> type:
4444
4545
Attributes that are not array containers are allowed. In order to decide
4646
whether an attribute is an array container, the declared attribute type
47-
is checked by the criteria from :func:`is_array_container`.
47+
is checked by the criteria from :func:`is_array_container_type`.
4848
"""
4949
from dataclasses import is_dataclass
5050
assert is_dataclass(cls)

arraycontext/container/traversal.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565

6666
from arraycontext.context import ArrayContext
6767
from arraycontext.container import (
68-
ContainerT, ArrayOrContainerT, is_array_container,
68+
ContainerT, ArrayOrContainerT, is_array_container_type,
6969
serialize_container, deserialize_container)
7070

7171

@@ -86,7 +86,7 @@ def _map_array_container_impl(
8686
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
8787
if type(_ary) is leaf_cls: # type(ary) is never None
8888
return f(_ary)
89-
elif is_array_container(_ary):
89+
elif is_array_container_type(_ary.__class__):
9090
return deserialize_container(_ary, [
9191
(key, frec(subary)) for key, subary in serialize_container(_ary)
9292
])
@@ -113,7 +113,7 @@ def _multimap_array_container_impl(
113113
def rec(*_args: Any) -> Any:
114114
template_ary = _args[container_indices[0]]
115115
if (type(template_ary) is leaf_cls
116-
or not is_array_container(template_ary)):
116+
or not is_array_container_type(template_ary.__class__)):
117117
return f(*_args)
118118

119119
assert all(
@@ -141,7 +141,7 @@ def rec(*_args: Any) -> Any:
141141

142142
container_indices: List[int] = [
143143
i for i, arg in enumerate(args)
144-
if is_array_container(arg) and type(arg) is not leaf_cls]
144+
if is_array_container_type(arg.__class__) and type(arg) is not leaf_cls]
145145

146146
if not container_indices:
147147
return f(*args)
@@ -453,7 +453,7 @@ def freeze(
453453
454454
See :meth:`ArrayContext.thaw`.
455455
"""
456-
if is_array_container(ary):
456+
if is_array_container_type(ary.__class__):
457457
return map_array_container(partial(freeze, actx=actx), ary)
458458
else:
459459
if actx is None:
@@ -634,7 +634,7 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any:
634634
def _from_numpy(subary: Any) -> Any:
635635
if isinstance(subary, np.ndarray) and subary.dtype != "O":
636636
return actx.from_numpy(subary)
637-
elif is_array_container(subary):
637+
elif is_array_container_type(subary.__class__):
638638
return map_array_container(_from_numpy, subary)
639639
else:
640640
raise TypeError(f"unrecognized array type: '{type(subary).__name__}'")

arraycontext/fake_numpy.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
import numpy as np
27-
from arraycontext.container import is_array_container, serialize_container
27+
from arraycontext.container import is_array_container_type, serialize_container
2828
from arraycontext.container.traversal import (
2929
rec_map_array_container, multimapped_over_array_containers)
3030
from pytools import memoize_in
@@ -182,7 +182,7 @@ def _new_like(self, ary, alloc_like):
182182
# e.g. `np.zeros_like(x)` returns `array([0, 0, ...], dtype=object)`
183183
# FIXME: what about object arrays nested in an ArrayContainer?
184184
raise NotImplementedError("operation not implemented for object arrays")
185-
elif is_array_container(ary):
185+
elif is_array_container_type(ary.__class__):
186186
return rec_map_array_container(alloc_like, ary)
187187
elif isinstance(ary, Number):
188188
# NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which
@@ -210,21 +210,6 @@ def conjugate(self, x):
210210

211211
# {{{ BaseFakeNumpyLinalgNamespace
212212

213-
def _scalar_list_norm(ary, ord):
214-
if ord is None:
215-
ord = 2
216-
217-
from numbers import Number
218-
if ord == np.inf:
219-
return max(ary)
220-
elif ord == -np.inf:
221-
return min(ary)
222-
elif isinstance(ord, Number) and ord > 0:
223-
return sum(iary**ord for iary in ary)**(1/ord)
224-
else:
225-
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
226-
227-
228213
def _reduce_norm(actx, arys, ord):
229214
from numbers import Number
230215
from functools import reduce
@@ -273,7 +258,7 @@ def norm(self, ary, ord=None):
273258

274259
return flat_norm(ary, ord=ord)
275260

276-
if is_array_container(ary):
261+
if is_array_container_type(ary.__class__):
277262
return _reduce_norm(actx, [
278263
self.norm(subary, ord=ord)
279264
for _, subary in serialize_container(ary)

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from arraycontext.fake_numpy import \
3333
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
34-
from arraycontext.container import is_array_container
34+
from arraycontext.container import is_array_container_type
3535
from arraycontext.container.traversal import (
3636
rec_map_array_container,
3737
rec_multimap_array_container,
@@ -249,7 +249,7 @@ def as_device_scalar(bool_value):
249249
def rec_equal(x, y):
250250
if type(x) != type(y):
251251
return as_device_scalar(False)
252-
elif not is_array_container(x):
252+
elif not is_array_container_type(x.__class__):
253253
if x.shape != y.shape:
254254
return as_device_scalar(False)
255255
else:

arraycontext/impl/pytato/compile.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from arraycontext.container import ArrayContainer
3131
from arraycontext import PytatoPyOpenCLArrayContext
3232
from arraycontext.container.traversal import (rec_keyed_map_array_container,
33-
is_array_container)
33+
is_array_container_type)
3434

3535
import numpy as np
3636
from typing import Any, Callable, Tuple, Dict, Mapping
@@ -119,7 +119,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
119119
arg_id = (kw,)
120120
arg_id_to_arg[arg_id] = arg
121121
arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg)))
122-
elif is_array_container(arg):
122+
elif is_array_container_type(arg.__class__):
123123
def id_collector(keys, ary):
124124
arg_id = (kw,) + keys
125125
arg_id_to_arg[arg_id] = ary
@@ -145,7 +145,7 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
145145
if np.isscalar(arg):
146146
name = arg_id_to_name[(kw,)]
147147
return pt.make_placeholder(name, (), np.dtype(type(arg)))
148-
elif is_array_container(arg):
148+
elif is_array_container_type(arg.__class__):
149149
def _rec_to_placeholder(keys, ary):
150150
name = arg_id_to_name[(kw,) + keys]
151151
return pt.make_placeholder(name, ary.shape, ary.dtype)
@@ -212,7 +212,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
212212
**{kw: _get_f_placeholder_args(arg, kw, input_naming_map)
213213
for kw, arg in kwargs.items()})
214214

215-
if not is_array_container(outputs):
215+
if not is_array_container_type(outputs.__class__):
216216
# TODO: We could possibly just short-circuit this interface if the
217217
# returned type is a scalar. Not sure if it's worth it though.
218218
raise NotImplementedError(

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from arraycontext.fake_numpy import (
2727
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
2828
)
29-
from arraycontext.container import is_array_container
29+
from arraycontext.container import is_array_container_type
3030
from arraycontext.container.traversal import (
3131
rec_map_array_container,
3232
rec_multimap_array_container,
@@ -181,7 +181,7 @@ def as_device_scalar(bool_value):
181181

182182
if type(a) != type(b):
183183
return as_device_scalar(False)
184-
elif not is_array_container(a):
184+
elif not is_array_container_type(a.__class__):
185185
if a.shape != b.shape:
186186
return as_device_scalar(False)
187187
else:

test/test_arraycontext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ def test_numpy_conversion(actx_factory):
977977
with pytest.raises(TypeError):
978978
from_numpy(ac_actx, actx)
979979

980-
with pytest.raises(ValueError):
980+
with pytest.raises(TypeError):
981981
to_numpy(ac, actx)
982982

983983
# }}}

0 commit comments

Comments
 (0)