Skip to content

Commit f0e3ac0

Browse files
alexfiklinducer
authored andcommitted
add specific array container exception type
1 parent 7b57511 commit f0e3ac0

File tree

6 files changed

+37
-34
lines changed

6 files changed

+37
-34
lines changed

arraycontext/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from .metadata import _FirstAxisIsElementsTag
3939

4040
from .container import (
41-
ArrayContainer,
41+
ArrayContainer, NotAnArrayContainerError,
4242
is_array_container, is_array_container_type,
4343
get_container_context, get_container_context_recursively,
4444
serialize_container, deserialize_container,
@@ -79,7 +79,7 @@
7979
"CommonSubexpressionTag",
8080
"ElementwiseMapKernelTag",
8181

82-
"ArrayContainer",
82+
"ArrayContainer", "NotAnArrayContainerError",
8383
"is_array_container", "is_array_container_type",
8484
"get_container_context", "get_container_context_recursively",
8585
"serialize_container", "deserialize_container",

arraycontext/container/__init__.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
2121
.. autoclass:: ArrayContainer
2222
23+
.. autoexception:: NotAnArrayContainerError
24+
2325
Serialization/deserialization
2426
-----------------------------
2527
.. autofunction:: is_array_container_type
@@ -115,6 +117,10 @@ class ArrayContainer:
115117
"""
116118

117119

120+
class NotAnArrayContainerError(TypeError):
121+
""":class:`TypeError` subclass raised when an array container is expected."""
122+
123+
118124
@singledispatch
119125
def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]:
120126
r"""Serialize the array container into an iterable over its components.
@@ -137,7 +143,8 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]:
137143
for arbitrarily nested structures. The identifiers need to be hashable
138144
but are otherwise treated as opaque.
139145
"""
140-
raise TypeError(f"'{type(ary).__name__}' cannot be serialized as a container")
146+
raise NotAnArrayContainerError(
147+
f"'{type(ary).__name__}' cannot be serialized as a container")
141148

142149

143150
@singledispatch
@@ -150,7 +157,7 @@ def deserialize_container(template: Any, iterable: Iterable[Tuple[Any, Any]]) ->
150157
:param iterable: an iterable that mirrors the output of
151158
:meth:`serialize_container`.
152159
"""
153-
raise TypeError(
160+
raise NotAnArrayContainerError(
154161
f"'{type(template).__name__}' cannot be deserialized as a container")
155162

156163

@@ -181,8 +188,8 @@ def is_array_container(ary: Any) -> bool:
181188
from warnings import warn
182189
warn("is_array_container is deprecated and will be removed in 2022. "
183190
"If you must know precisely whether something is an array container, "
184-
"try serializing it and catch TypeError. For a cheaper option, see "
185-
"is_array_container_type.",
191+
"try serializing it and catch NotAnArrayContainerError. For a "
192+
"cheaper option, see is_array_container_type.",
186193
DeprecationWarning, stacklevel=2)
187194
return (serialize_container.dispatch(ary.__class__)
188195
is not serialize_container.__wrapped__) # type:ignore[attr-defined]
@@ -206,7 +213,7 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]:
206213
@serialize_container.register(np.ndarray)
207214
def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]:
208215
if ary.dtype.char != "O":
209-
raise TypeError(
216+
raise NotAnArrayContainerError(
210217
f"cannot seriealize '{type(ary).__name__}' with dtype '{ary.dtype}'")
211218

212219
# special-cased for speed
@@ -254,7 +261,7 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
254261

255262
try:
256263
iterable = serialize_container(ary)
257-
except TypeError:
264+
except NotAnArrayContainerError:
258265
return actx
259266
else:
260267
for _, subary in iterable:

arraycontext/container/traversal.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969

7070
from arraycontext.context import ArrayContext
7171
from arraycontext.container import (
72-
ContainerT, ArrayOrContainerT,
72+
ContainerT, ArrayOrContainerT, NotAnArrayContainerError,
7373
serialize_container, deserialize_container)
7474

7575

@@ -93,7 +93,7 @@ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
9393

9494
try:
9595
iterable = serialize_container(_ary)
96-
except TypeError:
96+
except NotAnArrayContainerError:
9797
return f(_ary)
9898
else:
9999
return deserialize_container(_ary, [
@@ -127,7 +127,7 @@ def rec(*_args: Any) -> Any:
127127

128128
try:
129129
iterable_template = serialize_container(template_ary)
130-
except TypeError:
130+
except NotAnArrayContainerError:
131131
return f(*_args)
132132
else:
133133
pass
@@ -170,7 +170,7 @@ def rec(*_args: Any) -> Any:
170170
# FIXME: this will serialize again once `rec` is called, which is
171171
# not great, but it doesn't seem like there's a good way to avoid it
172172
_ = serialize_container(arg)
173-
except TypeError:
173+
except NotAnArrayContainerError:
174174
pass
175175
else:
176176
container_indices.append(i)
@@ -231,7 +231,7 @@ def map_array_container(
231231
"""
232232
try:
233233
iterable = serialize_container(ary)
234-
except TypeError:
234+
except NotAnArrayContainerError:
235235
return f(ary)
236236
else:
237237
return deserialize_container(ary, [
@@ -316,7 +316,7 @@ def keyed_map_array_container(f: Callable[[Any, Any], Any],
316316
"""
317317
try:
318318
iterable = serialize_container(ary)
319-
except TypeError:
319+
except NotAnArrayContainerError:
320320
raise ValueError(
321321
f"Non-array container type has no key: {type(ary).__name__}")
322322
else:
@@ -338,7 +338,7 @@ def rec(keys: Tuple[Union[str, int], ...],
338338
_ary: ArrayOrContainerT) -> ArrayOrContainerT:
339339
try:
340340
iterable = serialize_container(_ary)
341-
except TypeError:
341+
except NotAnArrayContainerError:
342342
return f(keys, _ary)
343343
else:
344344
return deserialize_container(_ary, [
@@ -367,7 +367,7 @@ def map_reduce_array_container(
367367
"""
368368
try:
369369
iterable = serialize_container(ary)
370-
except TypeError:
370+
except NotAnArrayContainerError:
371371
return map_func(ary)
372372
else:
373373
return reduce_func([
@@ -442,7 +442,7 @@ def rec_map_reduce_array_container(
442442
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
443443
try:
444444
iterable = serialize_container(_ary)
445-
except TypeError:
445+
except NotAnArrayContainerError:
446446
return map_func(_ary)
447447
else:
448448
return reduce_func([
@@ -501,7 +501,7 @@ def freeze(
501501
"""
502502
try:
503503
iterable = serialize_container(ary)
504-
except TypeError:
504+
except NotAnArrayContainerError:
505505
if actx is None:
506506
raise TypeError(
507507
f"cannot freeze arrays of type {type(ary).__name__} "
@@ -538,7 +538,7 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
538538
"""
539539
try:
540540
iterable = serialize_container(ary)
541-
except TypeError:
541+
except NotAnArrayContainerError:
542542
return actx.thaw(ary)
543543
else:
544544
return deserialize_container(ary, [
@@ -567,7 +567,7 @@ def _flatten(subary: ArrayOrContainerT) -> None:
567567

568568
try:
569569
iterable = serialize_container(subary)
570-
except TypeError:
570+
except NotAnArrayContainerError:
571571
if common_dtype is None:
572572
common_dtype = subary.dtype
573573

@@ -618,7 +618,7 @@ def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
618618

619619
try:
620620
iterable = serialize_container(template_subary)
621-
except TypeError:
621+
except NotAnArrayContainerError:
622622
if (offset + template_subary.size) > ary.size:
623623
raise ValueError("'template' and 'ary' sizes do not match: "
624624
"'template' is too large")
@@ -682,9 +682,7 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any:
682682
The conversion is done using :meth:`arraycontext.ArrayContext.from_numpy`.
683683
"""
684684
def _from_numpy_with_check(subary: Any) -> Any:
685-
if np.isscalar(subary):
686-
return subary
687-
elif isinstance(subary, np.ndarray):
685+
if isinstance(subary, np.ndarray) or np.isscalar(subary):
688686
return actx.from_numpy(subary)
689687
else:
690688
raise TypeError(f"array is not an ndarray: '{type(subary).__name__}'")
@@ -699,9 +697,7 @@ def to_numpy(ary: Any, actx: ArrayContext) -> Any:
699697
The conversion is done using :meth:`arraycontext.ArrayContext.to_numpy`.
700698
"""
701699
def _to_numpy_with_check(subary: Any) -> Any:
702-
if np.isscalar(subary):
703-
return subary
704-
elif isinstance(subary, actx.array_types):
700+
if isinstance(subary, actx.array_types) or np.isscalar(subary):
705701
return actx.to_numpy(subary)
706702
else:
707703
raise TypeError(
@@ -734,7 +730,7 @@ def outer(a: Any, b: Any) -> Any:
734730
def treat_as_scalar(x: Any) -> bool:
735731
try:
736732
serialize_container(x)
737-
except TypeError:
733+
except NotAnArrayContainerError:
738734
return True
739735
else:
740736
return (

arraycontext/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
import numpy as np
27-
from arraycontext.container import serialize_container
27+
from arraycontext.container import NotAnArrayContainerError, serialize_container
2828
from arraycontext.container.traversal import (
2929
rec_map_array_container, multimapped_over_array_containers)
3030
from pytools import memoize_in
@@ -258,7 +258,7 @@ def norm(self, ary, ord=None):
258258

259259
try:
260260
iterable = serialize_container(ary)
261-
except TypeError:
261+
except NotAnArrayContainerError:
262262
pass
263263
else:
264264
return _reduce_norm(actx, [

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from arraycontext.fake_numpy import \
3535
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
36-
from arraycontext.container import serialize_container
36+
from arraycontext.container import NotAnArrayContainerError, serialize_container
3737
from arraycontext.container.traversal import (
3838
rec_map_array_container,
3939
rec_multimap_array_container,
@@ -252,7 +252,7 @@ def rec_equal(x, y):
252252

253253
try:
254254
iterable = zip(serialize_container(x), serialize_container(y))
255-
except TypeError:
255+
except NotAnArrayContainerError:
256256
if x.shape != y.shape:
257257
return false
258258
else:

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from arraycontext.fake_numpy import (
2929
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
3030
)
31-
from arraycontext.container import serialize_container
31+
from arraycontext.container import NotAnArrayContainerError, serialize_container
3232
from arraycontext.container.traversal import (
3333
rec_map_array_container,
3434
rec_multimap_array_container,
@@ -186,7 +186,7 @@ def rec_equal(x, y):
186186

187187
try:
188188
iterable = zip(serialize_container(x), serialize_container(y))
189-
except TypeError:
189+
except NotAnArrayContainerError:
190190
if x.shape != y.shape:
191191
return false
192192
else:

0 commit comments

Comments
 (0)