Skip to content

Commit c826993

Browse files
alexfiklinducer
andcommitted
unflatten: better check that template and ary sizes match
Co-authored-by: Andreas Klöckner <[email protected]>
1 parent 853f501 commit c826993

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

arraycontext/container/traversal.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,8 @@ def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
570570
iterable = serialize_container(template_subary)
571571
except TypeError:
572572
if (offset + template_subary.size) > ary.size:
573-
raise ValueError("'template' and 'ary' sizes do not match")
573+
raise ValueError("'template' and 'ary' sizes do not match: "
574+
"'template' is too large")
574575

575576
if template_subary.dtype != ary.dtype:
576577
raise ValueError("'template' dtype does not match 'ary': "
@@ -612,7 +613,12 @@ def _unflatten(template_subary: ArrayOrContainerT) -> ArrayOrContainerT:
612613
"only one dimensional arrays can be unflattened: "
613614
f"'ary' has shape {ary.shape}")
614615

615-
return _unflatten(template)
616+
result = _unflatten(template)
617+
if offset != ary.size:
618+
raise ValueError("'template' and 'ary' sizes do not match: "
619+
"'ary' is too large")
620+
621+
return result
616622

617623
# }}}
618624

0 commit comments

Comments
 (0)