Skip to content

Commit c5ed317

Browse files
authored
Merge pull request #91 from alexfikl/flatten-to-numpy
Flatten entire array containers
2 parents 9c24abb + 6bc8d56 commit c5ed317

File tree

6 files changed

+232
-19
lines changed

6 files changed

+232
-19
lines changed

arraycontext/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
rec_map_reduce_array_container,
5959
rec_multimap_reduce_array_container,
6060
thaw, freeze,
61+
flatten, unflatten,
6162
from_numpy, to_numpy)
6263

6364
from .impl.pyopencl import PyOpenCLArrayContext
@@ -92,6 +93,7 @@
9293
"map_reduce_array_container", "multimap_reduce_array_container",
9394
"rec_map_reduce_array_container", "rec_multimap_reduce_array_container",
9495
"thaw", "freeze",
96+
"flatten", "unflatten",
9597
"from_numpy", "to_numpy",
9698

9799
"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",

arraycontext/container/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ def serialize_container(ary: ArrayContainer) -> Iterable[Tuple[Any, Any]]:
120120
r"""Serialize the array container into an iterable over its components.
121121
122122
The order of the components and their identifiers are entirely under
123-
the control of the container class.
123+
the control of the container class. However, the order is required to be
124+
deterministic, i.e. two calls to :func:`serialize_container` on
125+
array containers of the same types with the same number of
126+
sub-arrays must result in an iterable with the keys in the same
127+
order.
124128
125129
If *ary* is mutable, the serialization function is not required to ensure
126130
that the serialization result reflects the array state at the time of the

arraycontext/container/traversal.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
.. autofunction:: freeze
2424
.. autofunction:: thaw
2525
26+
Flattening and unflattening
27+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
28+
.. autofunction:: flatten
29+
.. autofunction:: unflatten
30+
2631
Numpy conversion
2732
~~~~~~~~~~~~~~~~
2833
.. autofunction:: from_numpy
@@ -493,6 +498,131 @@ def thaw(ary: ArrayOrContainerT, actx: ArrayContext) -> ArrayOrContainerT:
493498
# }}}
494499

495500

501+
# {{{ flatten / unflatten
502+
503+
def flatten(ary: ArrayOrContainerT, actx: ArrayContext) -> Any:
    """Convert all arrays in the :class:`~arraycontext.ArrayContainer`
    into a single flat array of a type in
    :attr:`arraycontext.ArrayContext.array_types`.

    The operation requires :attr:`arraycontext.ArrayContext.np` to have
    ``ravel`` and ``concatenate`` methods implemented. The order in which the
    individual leaf arrays appear in the final array is dependent on the order
    given by :func:`~arraycontext.serialize_container`.

    :raises ValueError: if the leaf arrays do not all share one ``dtype``.
    :raises NotImplementedError: if ``actx.np.ravel`` raises a
        :exc:`ValueError` for one of the leaf arrays.
    """
    expected_dtype = None
    flat_leaves: List[Any] = []

    def _collect(node: ArrayOrContainerT) -> None:
        nonlocal expected_dtype

        try:
            components = serialize_container(node)
        except TypeError:
            # serialize_container rejected *node*: treat it as a leaf array.
            # The first leaf encountered fixes the dtype all others must have.
            if expected_dtype is None:
                expected_dtype = node.dtype

            if node.dtype != expected_dtype:
                raise ValueError("arrays in container have different dtypes: "
                        f"got {node.dtype}, expected {expected_dtype}")

            try:
                flat_leaf = actx.np.ravel(node, order="C")
            except ValueError as exc:
                # NOTE: we can't do much if the array context fails to ravel,
                # since it is the one responsible for the actual memory layout
                strides_msg = (
                        f" and strides {node.strides}"
                        if hasattr(node, "strides") else "")

                raise NotImplementedError(
                        f"'{type(actx).__name__}.np.ravel' failed to reshape "
                        f"an array with shape {node.shape}{strides_msg}. "
                        "This functionality needs to be implemented by the "
                        "array context.") from exc

            flat_leaves.append(flat_leaf)
        else:
            # a container: descend into each component in serialization order
            for _, child in components:
                _collect(child)

    _collect(ary)

    return actx.np.concatenate(flat_leaves)
552+
553+
554+
def unflatten(
        template: ArrayOrContainerT, ary: Any,
        actx: ArrayContext) -> ArrayOrContainerT:
    """Unflatten an array *ary* produced by :func:`flatten` back into an
    :class:`~arraycontext.ArrayContainer`.

    The order and sizes of each slice into *ary* are determined by the
    array container *template*.

    :raises TypeError: if *ary* is not an instance of ``actx.array_types``.
    :raises ValueError: if *ary* is not one-dimensional, if the sizes or
        dtypes of *template* and *ary* do not match, or if a reshaped leaf
        has different strides than the corresponding template leaf.
    :raises NotImplementedError: if ``actx.np.reshape`` raises a
        :exc:`ValueError` for one of the leaves.
    """
    # validate the flat input up front, before walking the template
    if not isinstance(ary, actx.array_types):
        raise TypeError("'ary' does not have a type supported by the provided "
                f"array context: got '{type(ary).__name__}', expected one of "
                f"{actx.array_types}")

    if ary.ndim != 1:
        raise ValueError(
                "only one dimensional arrays can be unflattened: "
                f"'ary' has shape {ary.shape}")

    # NOTE: https://github.com/python/mypy/issues/7057
    consumed = 0

    def _rebuild(node: ArrayOrContainerT) -> ArrayOrContainerT:
        nonlocal consumed

        try:
            components = serialize_container(node)
        except TypeError:
            # serialize_container rejected *node*: it is a leaf, so carve the
            # next `node.size` entries out of the flat array for it.
            stop = consumed + node.size
            if stop > ary.size:
                raise ValueError("'template' and 'ary' sizes do not match: "
                        "'template' is too large")

            if node.dtype != ary.dtype:
                raise ValueError("'template' dtype does not match 'ary': "
                        f"got {node.dtype}, expected {ary.dtype}")

            flat_leaf = ary[consumed:stop]
            try:
                leaf = actx.np.reshape(flat_leaf, node.shape, order="C")
            except ValueError as exc:
                # NOTE: we can't do much if the array context fails to reshape,
                # since it is the one responsible for the actual memory layout
                raise NotImplementedError(
                        f"'{type(actx).__name__}.np.reshape' failed to reshape "
                        f"the flat array into shape {node.shape}. "
                        "This functionality needs to be implemented by the "
                        "array context.") from exc

            # only compared when the template leaf exposes strides at all
            if hasattr(node, "strides") and node.strides != leaf.strides:
                raise ValueError(
                        f"strides do not match template: got {leaf.strides}, "
                        f"expected {node.strides}")

            consumed = stop
            return leaf
        else:
            # a container: rebuild every component, then reassemble
            rebuilt = [(key, _rebuild(child)) for key, child in components]
            return deserialize_container(node, rebuilt)

    result = _rebuild(template)

    # anything left over means *ary* holds more data than *template* describes
    if consumed != ary.size:
        raise ValueError("'template' and 'ary' sizes do not match: "
                "'ary' is too large")

    return result
622+
623+
# }}}
624+
625+
496626
# {{{ numpy conversion
497627

498628
def from_numpy(ary: Any, actx: ArrayContext) -> Any:

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,10 @@ def stack(self, arrays, axis=0):
172172
queue=self._array_context.queue),
173173
*arrays)
174174

175-
def reshape(self, a, newshape):
176-
return cl_array.reshape(a, newshape)
175+
def reshape(self, a, newshape, order="C"):
    """Reshape every array leaf of *a* to *newshape*.

    Each leaf visited by ``rec_map_array_container`` is reshaped by calling
    its own ``reshape`` method with the given *order*.
    """
    def _reshape_leaf(leaf):
        return leaf.reshape(newshape, order=order)

    return rec_map_array_container(_reshape_leaf, a)
177179

178180
def concatenate(self, arrays, axis=0):
179181
return cl_array.concatenate(

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ def __getattr__(self, name):
6464

6565
return super().__getattr__(name)
6666

67-
def reshape(self, a, newshape):
68-
return rec_multimap_array_container(pt.reshape, a, newshape)
67+
def reshape(self, a, newshape, order="C"):
    """Reshape every array leaf of *a* to *newshape* via ``pt.reshape``.

    :arg a: an array or array container whose leaves are reshaped.
    :arg newshape: the target shape, forwarded to ``pt.reshape``.
    :arg order: the memory order, forwarded to ``pt.reshape``.
    """
    # NOTE: the mapped function must reshape the leaf it is handed (*ary*),
    # not the enclosing argument *a*; passing *a* would hand the whole
    # container to ``pt.reshape`` once per leaf.
    return rec_map_array_container(
        lambda ary: pt.reshape(ary, newshape, order=order),
        a)
6971

7072
def transpose(self, a, axes=None):
7173
return rec_multimap_array_container(pt.transpose, a, axes)

test/test_arraycontext.py

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def imag(self):
159159

160160
@serialize_container.register(DOFArray)
161161
def _serialize_dof_container(ary: DOFArray):
162-
return enumerate(ary.data)
162+
return list(enumerate(ary.data))
163163

164164

165165
@deserialize_container.register(DOFArray)
@@ -203,17 +203,27 @@ def randn(shape, dtype):
203203
rng = np.random.default_rng()
204204
dtype = np.dtype(dtype)
205205

206+
if shape == 0:
207+
ashape = 1
208+
else:
209+
ashape = shape
210+
206211
if dtype.kind == "c":
207212
dtype = np.dtype(f"<f{dtype.itemsize // 2}")
208-
return rng.standard_normal(shape, dtype) \
209-
+ 1j * rng.standard_normal(shape, dtype)
213+
r = rng.standard_normal(ashape, dtype) \
214+
+ 1j * rng.standard_normal(ashape, dtype)
210215
elif dtype.kind == "f":
211-
return rng.standard_normal(shape, dtype)
216+
r = rng.standard_normal(ashape, dtype)
212217
elif dtype.kind == "i":
213-
return rng.integers(0, 128, shape, dtype)
218+
r = rng.integers(0, 512, ashape, dtype)
214219
else:
215220
raise TypeError(dtype.kind)
216221

222+
if shape == 0:
223+
return np.array(r[0])
224+
225+
return r
226+
217227

218228
def assert_close_to_numpy(actx, op, args):
219229
assert np.allclose(
@@ -672,11 +682,14 @@ def array_context(self):
672682
return self.mass.array_context
673683

674684

675-
def _get_test_containers(actx, ambient_dim=2, size=50_000):
676-
if size == 0:
677-
x = DOFArray(actx, (actx.from_numpy(np.array(np.random.randn())),))
678-
else:
679-
x = DOFArray(actx, (actx.from_numpy(np.random.randn(size)),))
685+
def _get_test_containers(actx, ambient_dim=2, shapes=50_000):
686+
from numbers import Number
687+
if isinstance(shapes, (Number, tuple)):
688+
shapes = [shapes]
689+
690+
x = DOFArray(actx, tuple([
691+
actx.from_numpy(randn(shape, np.float64))
692+
for shape in shapes]))
680693

681694
# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
682695
dataclass_of_dofs = MyContainer(
@@ -705,7 +718,7 @@ def _get_test_containers(actx, ambient_dim=2, size=50_000):
705718
def test_container_scalar_map(actx_factory):
706719
actx = actx_factory()
707720

708-
arys = _get_test_containers(actx, size=0)
721+
arys = _get_test_containers(actx, shapes=0)
709722
arys += (np.pi,)
710723

711724
from arraycontext import (
@@ -877,16 +890,76 @@ def test_container_norm(actx_factory, ord):
877890
# }}}
878891

879892

893+
# {{{ test flatten and unflatten
894+
895+
@pytest.mark.parametrize("shapes", [
    0,                      # tests device scalars when flattening
    512,
    [(128, 67)],
    [(127, 67), (18, 0)],   # tests 0-sized arrays
    [(64, 7), (154, 12)]
    ])
def test_flatten_array_container(actx_factory, shapes):
    """Round-trip the test containers through ``flatten`` and ``unflatten``.

    Checks that the flattened array is one-dimensional, that every leaf of
    the round-tripped container has the same shape as the original, and that
    the norm of the difference is below ``1.0e-15``.
    """
    # https://github.com/inducer/loopy/pull/497
    # NOTE: only fails for the pytato array context at the moment
    # NOTE(review): np.prod(0) == 0, so the scalar case (shapes=0) also takes
    # this xfail path -- confirm that this is intended.
    if np.prod(shapes) == 0:
        pytest.xfail("strides do not match in subary")

    actx = actx_factory()

    from arraycontext import flatten, unflatten
    from arraycontext import rec_multimap_reduce_array_container

    def _same_shape(x, y):
        return x.shape == y.shape

    containers = _get_test_containers(actx, shapes=shapes)

    for container in containers:
        flat = flatten(container, actx)
        assert flat.ndim == 1

        roundtrip = unflatten(container, flat, actx)

        # np.prod over the per-leaf booleans acts as a logical "all"
        assert rec_multimap_reduce_array_container(
                np.prod, _same_shape, container, roundtrip)

        error = actx.to_numpy(actx.np.linalg.norm(container - roundtrip))
        assert error < 1.0e-15
928+
929+
930+
def test_flatten_array_container_failure(actx_factory):
    """Check that ``unflatten`` rejects inputs that do not match the template."""
    actx = actx_factory()

    from arraycontext import flatten, unflatten
    template = _get_test_containers(actx, shapes=512)[0]
    flat = flatten(template, actx)

    # cannot unflatten from a numpy array
    with pytest.raises(TypeError):
        unflatten(template, actx.to_numpy(flat), actx)

    # cannot unflatten non-flat arrays
    with pytest.raises(ValueError):
        unflatten(template, flat.reshape(2, -1), actx)

    # cannot unflatten partially
    with pytest.raises(ValueError):
        unflatten(template, flat[:-1], actx)
948+
949+
# }}}
950+
951+
880952
# {{{ test from_numpy and to_numpy
881953

882954
def test_numpy_conversion(actx_factory):
883955
actx = actx_factory()
884956

957+
nelements = 42
885958
ac = MyContainer(
886959
name="test_numpy_conversion",
887-
mass=np.random.rand(42),
888-
momentum=make_obj_array([np.random.rand(42) for _ in range(3)]),
889-
enthalpy=np.random.rand(42),
960+
mass=np.random.rand(nelements, nelements),
961+
momentum=make_obj_array([np.random.rand(nelements) for _ in range(3)]),
962+
enthalpy=np.array(np.random.rand()),
890963
)
891964

892965
from arraycontext import from_numpy, to_numpy

0 commit comments

Comments
 (0)