Skip to content

Commit bec3052

Browse files
alexfiklinducer
authored andcommitted
allow unflatten to skip dtype and stride checks
1 parent f0e3ac0 commit bec3052

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

arraycontext/container/traversal.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -603,29 +603,51 @@ def _flatten(subary: ArrayOrContainerT) -> None:
603603

604604
def unflatten(
605605
template: ArrayOrContainerT, ary: Any,
606-
actx: ArrayContext) -> ArrayOrContainerT:
606+
actx: ArrayContext, *,
607+
strict: bool = True) -> ArrayOrContainerT:
607608
"""Unflatten an array *ary* produced by :func:`flatten` back into an
608609
:class:`~arraycontext.ArrayContainer`.
609610
610611
The order and sizes of each slice into *ary* are determined by the
611612
array container *template*.
613+
614+
:arg strict: if *True* additional :class:`~numpy.dtype` and stride
615+
checking is performed on the unflattened array. Otherwise, these
616+
checks are skipped.
612617
"""
613618
# NOTE: https://github.com/python/mypy/issues/7057
614619
offset = 0
620+
common_dtype = None
615621

616622
def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
617-
nonlocal offset
623+
nonlocal offset, common_dtype
618624

619625
try:
620626
iterable = serialize_container(template_subary)
621627
except NotAnArrayContainerError:
628+
# {{{ validate subary
629+
622630
if (offset + template_subary.size) > ary.size:
623631
raise ValueError("'template' and 'ary' sizes do not match: "
624632
"'template' is too large")
625633

626-
if template_subary.dtype != ary.dtype:
627-
raise ValueError("'template' dtype does not match 'ary': "
628-
f"got {template_subary.dtype}, expected {ary.dtype}")
634+
if strict:
635+
if template_subary.dtype != ary.dtype:
636+
raise ValueError("'template' dtype does not match 'ary': "
637+
f"got {template_subary.dtype}, expected {ary.dtype}")
638+
else:
639+
# NOTE: still require that *template* has a uniform dtype
640+
if common_dtype is None:
641+
common_dtype = template_subary.dtype
642+
else:
643+
if common_dtype != template_subary.dtype:
644+
raise ValueError("arrays in 'template' have different "
645+
f"dtypes: got {template_subary.dtype}, but "
646+
f"expected {common_dtype}.")
647+
648+
# }}}
649+
650+
# {{{ reshape
629651

630652
flat_subary = ary[offset:offset + template_subary.size]
631653
try:
@@ -640,12 +662,18 @@ def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
640662
"This functionality needs to be implemented by the "
641663
"array context.") from exc
642664

643-
if hasattr(template_subary, "strides"):
665+
# }}}
666+
667+
# {{{ check strides
668+
669+
if strict and hasattr(template_subary, "strides"):
644670
if template_subary.strides != subary.strides:
645671
raise ValueError(
646672
f"strides do not match template: got {subary.strides}, "
647673
f"expected {template_subary.strides}")
648674

675+
# }}}
676+
649677
offset += template_subary.size
650678
return subary
651679
else:

test/test_arraycontext.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,28 @@ def test_flatten_array_container(actx_factory, shapes):
926926
actx.np.linalg.norm(ary - ary_roundtrip)
927927
) < 1.0e-15
928928

929+
# {{{ complex to real
930+
931+
if isinstance(shapes, (int, tuple)):
932+
shapes = [shapes]
933+
934+
ary = DOFArray(actx, tuple([
935+
actx.from_numpy(randn(shape, np.float64))
936+
for shape in shapes]))
937+
938+
template = DOFArray(actx, tuple([
939+
actx.from_numpy(randn(shape, np.complex128))
940+
for shape in shapes]))
941+
942+
flat = flatten(ary, actx)
943+
ary_roundtrip = unflatten(template, flat, actx, strict=False)
944+
945+
assert actx.to_numpy(
946+
actx.np.linalg.norm(ary - ary_roundtrip)
947+
) < 1.0e-15
948+
949+
# }}}
950+
929951

930952
def test_flatten_array_container_failure(actx_factory):
931953
actx = actx_factory()

0 commit comments

Comments
 (0)