Skip to content

Commit 3b08761

Browse files
committed
Merge branch 'main' into pytato-array-context-transforms
2 parents 714272a + 8f86278 commit 3b08761

File tree

9 files changed

+69
-29
lines changed

9 files changed

+69
-29
lines changed

arraycontext/container/__init__.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
2323
Serialization/deserialization
2424
-----------------------------
25-
.. autofunction:: is_array_container
25+
.. autofunction:: is_array_container_type
2626
.. autofunction:: serialize_container
2727
.. autofunction:: deserialize_container
2828
@@ -95,8 +95,6 @@ class ArrayContainer:
9595
9696
This allows enumeration of the component arrays in a container and the
9797
construction of modified containers from an iterable of those component arrays.
98-
:func:`is_array_container` will return *True* for types that have
99-
a container serialization function registered.
10098
10199
Packages may register their own types as array containers. They must not
102100
register other types (e.g. :class:`list`) as array containers.
@@ -156,6 +154,13 @@ def is_array_container_type(cls: type) -> bool:
156154
"""
157155
:returns: *True* if the type *cls* has a registered implementation of
158156
:func:`serialize_container`, or if it is an :class:`ArrayContainer`.
157+
158+
.. warning::
159+
160+
Not all instances of a type that this function labels an array container
161+
must automatically be array containers. For example, while this
162+
function will say that :class:`numpy.ndarray` is an array container
163+
type, only object arrays *actually are* array containers.
159164
"""
160165
return (
161166
cls is ArrayContainer
@@ -168,6 +173,13 @@ def is_array_container(ary: Any) -> bool:
168173
:returns: *True* if the instance *ary* has a registered implementation of
169174
:func:`serialize_container`.
170175
"""
176+
177+
from warnings import warn
178+
warn("is_array_container is deprecated and will be removed in 2022. "
179+
"If you must know precisely whether something is an array container, "
180+
"try serializing it and catch TypeError. For a cheaper option, see "
181+
"is_array_container_type.",
182+
DeprecationWarning, stacklevel=2)
171183
return (serialize_container.dispatch(ary.__class__)
172184
is not serialize_container.__wrapped__) # type:ignore[attr-defined]
173185

@@ -190,7 +202,7 @@ def get_container_context(ary: ArrayContainer) -> Optional[ArrayContext]:
190202
@serialize_container.register(np.ndarray)
191203
def _serialize_ndarray_container(ary: np.ndarray) -> Iterable[Tuple[Any, Any]]:
192204
if ary.dtype.char != "O":
193-
raise ValueError(
205+
raise TypeError(
194206
f"cannot seriealize '{type(ary).__name__}' with dtype '{ary.dtype}'")
195207

196208
# special-cased for speed
@@ -232,7 +244,7 @@ def get_container_context_recursively(ary: Any) -> Optional[ArrayContext]:
232244
any level, an assertion error is raised.
233245
"""
234246
actx = None
235-
if not is_array_container(ary):
247+
if not is_array_container_type(ary.__class__):
236248
return actx
237249

238250
# try getting the array context directly

arraycontext/container/dataclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def dataclass_array_container(cls: type) -> type:
4444
4545
Attributes that are not array containers are allowed. In order to decide
4646
whether an attribute is an array container, the declared attribute type
47-
is checked by the criteria from :func:`is_array_container`.
47+
is checked by the criteria from :func:`is_array_container_type`.
4848
"""
4949
from dataclasses import is_dataclass
5050
assert is_dataclass(cls)

arraycontext/container/traversal.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060

6161
from arraycontext.context import ArrayContext
6262
from arraycontext.container import (
63-
ContainerT, ArrayOrContainerT, is_array_container,
63+
ContainerT, ArrayOrContainerT, is_array_container_type,
6464
serialize_container, deserialize_container)
6565

6666

@@ -81,7 +81,7 @@ def _map_array_container_impl(
8181
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
8282
if type(_ary) is leaf_cls: # type(ary) is never None
8383
return f(_ary)
84-
elif is_array_container(_ary):
84+
elif is_array_container_type(_ary.__class__):
8585
return deserialize_container(_ary, [
8686
(key, frec(subary)) for key, subary in serialize_container(_ary)
8787
])
@@ -108,7 +108,7 @@ def _multimap_array_container_impl(
108108
def rec(*_args: Any) -> Any:
109109
template_ary = _args[container_indices[0]]
110110
if (type(template_ary) is leaf_cls
111-
or not is_array_container(template_ary)):
111+
or not is_array_container_type(template_ary.__class__)):
112112
return f(*_args)
113113

114114
assert all(
@@ -136,7 +136,7 @@ def rec(*_args: Any) -> Any:
136136

137137
container_indices: List[int] = [
138138
i for i, arg in enumerate(args)
139-
if is_array_container(arg) and type(arg) is not leaf_cls]
139+
if is_array_container_type(arg.__class__) and type(arg) is not leaf_cls]
140140

141141
if not container_indices:
142142
return f(*args)
@@ -448,7 +448,7 @@ def freeze(
448448
449449
See :meth:`ArrayContext.thaw`.
450450
"""
451-
if is_array_container(ary):
451+
if is_array_container_type(ary.__class__):
452452
return map_array_container(partial(freeze, actx=actx), ary)
453453
else:
454454
if actx is None:
@@ -504,7 +504,7 @@ def from_numpy(ary: Any, actx: ArrayContext) -> Any:
504504
def _from_numpy(subary: Any) -> Any:
505505
if isinstance(subary, np.ndarray) and subary.dtype != "O":
506506
return actx.from_numpy(subary)
507-
elif is_array_container(subary):
507+
elif is_array_container_type(subary.__class__):
508508
return map_array_container(_from_numpy, subary)
509509
else:
510510
raise TypeError(f"unrecognized array type: '{type(subary).__name__}'")

arraycontext/fake_numpy.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
import numpy as np
27-
from arraycontext.container import is_array_container, serialize_container
27+
from arraycontext.container import is_array_container_type, serialize_container
2828
from arraycontext.container.traversal import (
2929
rec_map_array_container, multimapped_over_array_containers)
3030
from pytools import memoize_in
@@ -182,7 +182,7 @@ def _new_like(self, ary, alloc_like):
182182
# e.g. `np.zeros_like(x)` returns `array([0, 0, ...], dtype=object)`
183183
# FIXME: what about object arrays nested in an ArrayContainer?
184184
raise NotImplementedError("operation not implemented for object arrays")
185-
elif is_array_container(ary):
185+
elif is_array_container_type(ary.__class__):
186186
return rec_map_array_container(alloc_like, ary)
187187
elif isinstance(ary, Number):
188188
# NOTE: `np.zeros_like(x)` returns `array(x, shape=())`, which
@@ -225,6 +225,26 @@ def _scalar_list_norm(ary, ord):
225225
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
226226

227227

228+
def _reduce_norm(actx, arys, ord):
229+
from numbers import Number
230+
from functools import reduce
231+
232+
if ord is None:
233+
ord = 2
234+
235+
# NOTE: these are ordered by an expected usage frequency
236+
if ord == 2:
237+
return actx.np.sqrt(sum(subary*subary for subary in arys))
238+
elif ord == np.inf:
239+
return reduce(actx.np.maximum, arys)
240+
elif ord == -np.inf:
241+
return reduce(actx.np.minimum, arys)
242+
elif isinstance(ord, Number) and ord > 0:
243+
return sum(subary**ord for subary in arys)**(1/ord)
244+
else:
245+
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
246+
247+
228248
class BaseFakeNumpyLinalgNamespace:
229249
def __init__(self, array_context):
230250
self._array_context = array_context
@@ -253,8 +273,8 @@ def norm(self, ary, ord=None):
253273

254274
return flat_norm(ary, ord=ord)
255275

256-
if is_array_container(ary):
257-
return _scalar_list_norm([
276+
if is_array_container_type(ary.__class__):
277+
return _reduce_norm(actx, [
258278
self.norm(subary, ord=ord)
259279
for _, subary in serialize_container(ary)
260280
], ord=ord)
@@ -266,14 +286,16 @@ def norm(self, ary, ord=None):
266286
raise NotImplementedError("only vector norms are implemented")
267287

268288
if ary.size == 0:
269-
return 0
289+
return ary.dtype.type(0)
270290

291+
if ord == 2:
292+
return actx.np.sqrt(actx.np.sum(abs(ary)**2))
271293
if ord == np.inf:
272-
return self._array_context.np.max(abs(ary))
294+
return actx.np.max(abs(ary))
273295
elif ord == -np.inf:
274-
return self._array_context.np.min(abs(ary))
296+
return actx.np.min(abs(ary))
275297
elif isinstance(ord, Number) and ord > 0:
276-
return self._array_context.np.sum(abs(ary)**ord)**(1/ord)
298+
return actx.np.sum(abs(ary)**ord)**(1/ord)
277299
else:
278300
raise NotImplementedError(f"unsupported value of 'ord': {ord}")
279301
# }}}

arraycontext/impl/pyopencl/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,17 @@ def transform_loopy_program(self, t_unit):
244244

245245
if "idof" in all_inames:
246246
inner_iname = "idof"
247+
247248
elif "i0" in all_inames:
248249
outer_iname = "i0"
249250

250251
if "i1" in all_inames:
251252
inner_iname = "i1"
253+
254+
elif not all_inames:
255+
# no loops, nothing to transform
256+
return t_unit
257+
252258
else:
253259
raise RuntimeError(
254260
"Unable to reason what outer_iname and inner_iname "

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
from arraycontext.fake_numpy import \
3333
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace
34-
from arraycontext.container import is_array_container
34+
from arraycontext.container import is_array_container_type
3535
from arraycontext.container.traversal import (
3636
rec_map_array_container,
3737
rec_multimap_array_container,
@@ -247,7 +247,7 @@ def as_device_scalar(bool_value):
247247
def rec_equal(x, y):
248248
if type(x) != type(y):
249249
return as_device_scalar(False)
250-
elif not is_array_container(x):
250+
elif not is_array_container_type(x.__class__):
251251
if x.shape != y.shape:
252252
return as_device_scalar(False)
253253
else:

arraycontext/impl/pytato/compile.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from arraycontext.container import ArrayContainer
3131
from arraycontext import PytatoPyOpenCLArrayContext
3232
from arraycontext.container.traversal import (rec_keyed_map_array_container,
33-
is_array_container)
33+
is_array_container_type)
3434

3535
import numpy as np
3636
from typing import Any, Callable, Tuple, Dict, Mapping
@@ -124,7 +124,7 @@ def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
124124
arg_id = (kw,)
125125
arg_id_to_arg[arg_id] = arg
126126
arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg)))
127-
elif is_array_container(arg):
127+
elif is_array_container_type(arg.__class__):
128128
def id_collector(keys, ary):
129129
arg_id = (kw,) + keys
130130
arg_id_to_arg[arg_id] = ary
@@ -150,7 +150,7 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
150150
if np.isscalar(arg):
151151
name = arg_id_to_name[(kw,)]
152152
return pt.make_placeholder(name, (), np.dtype(type(arg)))
153-
elif is_array_container(arg):
153+
elif is_array_container_type(arg.__class__):
154154
def _rec_to_placeholder(keys, ary):
155155
name = arg_id_to_name[(kw,) + keys]
156156
return pt.make_placeholder(name, ary.shape, ary.dtype)
@@ -217,7 +217,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
217217
**{kw: _get_f_placeholder_args(arg, kw, input_naming_map)
218218
for kw, arg in kwargs.items()})
219219

220-
if not is_array_container(outputs):
220+
if not is_array_container_type(outputs.__class__):
221221
# TODO: We could possibly just short-circuit this interface if the
222222
# returned type is a scalar. Not sure if it's worth it though.
223223
raise NotImplementedError(

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from arraycontext.fake_numpy import (
2727
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
2828
)
29-
from arraycontext.container import is_array_container
29+
from arraycontext.container import is_array_container_type
3030
from arraycontext.container.traversal import (
3131
rec_map_array_container,
3232
rec_multimap_array_container,
@@ -179,7 +179,7 @@ def as_device_scalar(bool_value):
179179

180180
if type(a) != type(b):
181181
return as_device_scalar(False)
182-
elif not is_array_container(a):
182+
elif not is_array_container_type(a.__class__):
183183
if a.shape != b.shape:
184184
return as_device_scalar(False)
185185
else:

test/test_arraycontext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,7 @@ def test_numpy_conversion(actx_factory):
904904
with pytest.raises(TypeError):
905905
from_numpy(ac_actx, actx)
906906

907-
with pytest.raises(ValueError):
907+
with pytest.raises(TypeError):
908908
to_numpy(ac, actx)
909909

910910
# }}}

0 commit comments

Comments
 (0)