@@ -603,29 +603,51 @@ def _flatten(subary: ArrayOrContainerT) -> None:
603603
604604def unflatten (
605605 template : ArrayOrContainerT , ary : Any ,
606- actx : ArrayContext ) -> ArrayOrContainerT :
606+ actx : ArrayContext , * ,
607+ strict : bool = True ) -> ArrayOrContainerT :
607608 """Unflatten an array *ary* produced by :func:`flatten` back into an
608609 :class:`~arraycontext.ArrayContainer`.
609610
610611 The order and sizes of each slice into *ary* are determined by the
611612 array container *template*.
613+
614+ :arg strict: if *True* additional :class:`~numpy.dtype` and stride
615+ checking is performed on the unflattened array. Otherwise, these
616+ checks are skipped.
612617 """
613618 # NOTE: https://github.com/python/mypy/issues/7057
614619 offset = 0
620+ common_dtype = None
615621
616622 def _unflatten (template_subary : ArrayOrContainerT ) -> ArrayOrContainerT :
617- nonlocal offset
623+ nonlocal offset , common_dtype
618624
619625 try :
620626 iterable = serialize_container (template_subary )
621627 except NotAnArrayContainerError :
628+ # {{{ validate subary
629+
622630 if (offset + template_subary .size ) > ary .size :
623631 raise ValueError ("'template' and 'ary' sizes do not match: "
624632 "'template' is too large" )
625633
626- if template_subary .dtype != ary .dtype :
627- raise ValueError ("'template' dtype does not match 'ary': "
628- f"got { template_subary .dtype } , expected { ary .dtype } " )
634+ if strict :
635+ if template_subary .dtype != ary .dtype :
636+ raise ValueError ("'template' dtype does not match 'ary': "
637+ f"got { template_subary .dtype } , expected { ary .dtype } " )
638+ else :
639+ # NOTE: still require that *template* has a uniform dtype
640+ if common_dtype is None :
641+ common_dtype = template_subary .dtype
642+ else :
643+ if common_dtype != template_subary .dtype :
644+ raise ValueError ("arrays in 'template' have different "
645+ f"dtypes: got { template_subary .dtype } , but "
646+ f"expected { common_dtype } ." )
647+
648+ # }}}
649+
650+ # {{{ reshape
629651
630652 flat_subary = ary [offset :offset + template_subary .size ]
631653 try :
@@ -640,12 +662,18 @@ def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
640662 "This functionality needs to be implemented by the "
641663 "array context." ) from exc
642664
643- if hasattr (template_subary , "strides" ):
665+ # }}}
666+
667+ # {{{ check strides
668+
669+ if strict and hasattr (template_subary , "strides" ):
644670 if template_subary .strides != subary .strides :
645671 raise ValueError (
646672 f"strides do not match template: got { subary .strides } , "
647673 f"expected { template_subary .strides } " )
648674
675+ # }}}
676+
649677 offset += template_subary .size
650678 return subary
651679 else :
0 commit comments