Skip to content

Commit 5d712a8

Browse files
committed
ENH: Add unstack()
unstack() is a new function from the 2023.12 version of the array API, which serves as the inverse of stack(), that is, unstack(x, axis=axis) returns a tuple of arrays along axis that when stacked along axis would recreate x. The implementation is in pure Python, since it is just a straightforward iteration along an index, but if it is preferable it can be moved to an implementation in C. I haven't yet added any tests.
1 parent 657c714 commit 5d712a8

File tree

2 files changed

+80
-6
lines changed

2 files changed

+80
-6
lines changed

numpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@
165165
str_, subtract, sum, swapaxes, take, tan, tanh, tensordot,
166166
timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte,
167167
ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong,
168-
ulonglong, unsignedinteger, ushort, var, vdot, vecdot, void, vstack,
169-
where, zeros, zeros_like
168+
ulonglong, unsignedinteger, unstack, ushort, var, vdot, vecdot, void,
169+
vstack, where, zeros, zeros_like
170170
)
171171

172172
# NOTE: It's still under discussion whether these aliases

numpy/_core/shape_base.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'block', 'hstack',
2-
'stack', 'vstack']
2+
'stack', 'unstack', 'vstack']
33

44
import functools
55
import itertools
@@ -11,7 +11,6 @@
1111
from .multiarray import array, asanyarray, normalize_axis_index
1212
from . import fromnumeric as _from_nx
1313

14-
1514
array_function_dispatch = functools.partial(
1615
overrides.array_function_dispatch, module='numpy')
1716

@@ -261,6 +260,7 @@ def vstack(tup, *, dtype=None, casting="same_kind"):
261260
dstack : Stack arrays in sequence depth wise (along third axis).
262261
column_stack : Stack 1-D arrays as columns into a 2-D array.
263262
vsplit : Split an array into multiple sub-arrays vertically (row-wise).
263+
unstack : Split an array into a tuple of sub-arrays along an axis.
264264
265265
Examples
266266
--------
@@ -331,8 +331,9 @@ def hstack(tup, *, dtype=None, casting="same_kind"):
331331
vstack : Stack arrays in sequence vertically (row wise).
332332
dstack : Stack arrays in sequence depth wise (along third axis).
333333
column_stack : Stack 1-D arrays as columns into a 2-D array.
334-
hsplit : Split an array into multiple sub-arrays
334+
hsplit : Split an array into multiple sub-arrays
335335
horizontally (column-wise).
336+
unstack : Split an array into a tuple of sub-arrays along an axis.
336337
337338
Examples
338339
--------
@@ -414,6 +415,7 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
414415
concatenate : Join a sequence of arrays along an existing axis.
415416
block : Assemble an nd-array from nested lists of blocks.
416417
split : Split array into a list of multiple sub-arrays of equal size.
418+
unstack : Split an array into a tuple of sub-arrays along an axis.
417419
418420
Examples
419421
--------
@@ -455,6 +457,77 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
455457
return _nx.concatenate(expanded_arrays, axis=axis, out=out,
456458
dtype=dtype, casting=casting)
457459

460+
def _unstack_dispatcher(x, *, axis=None):
461+
return (x,)
462+
463+
@array_function_dispatch(_unstack_dispatcher)
464+
def unstack(x, /, *, axis=0):
465+
"""
466+
Splits an array into a sequence of arrays along the given axis.
467+
468+
The ``axis`` parameter specifies the axis along which the array will be
469+
split. of the new axis in the dimensions of the result. For example, if
470+
``axis=0`` it will be the first dimension and if ``axis=-1`` it will be
471+
the last dimension.
472+
473+
The result is a tuple of arrays split along ``axis``. ``unstack`` serves
474+
as the reverse operation of :py:func:`stack`, i.e., ``stack(unstack(x,
475+
axis=axis), axis=axis) == x``.
476+
477+
.. versionadded:: 2.1.0
478+
479+
Parameters
480+
----------
481+
x : ndarray
482+
The array to be unstacked.
483+
484+
Returns
485+
-------
486+
unstacked : tuple of ndarrays
487+
The unstacked arrays.
488+
489+
See Also
490+
--------
491+
stack : Join a sequence of arrays along a new axis.
492+
concatenate : Join a sequence of arrays along an existing axis.
493+
vstack : Stack arrays in sequence vertically (row wise).
494+
hstack : Stack arrays in sequence horizontally (column wise).
495+
dstack : Stack arrays in sequence depth wise (along third axis).
496+
column_stack : Stack 1-D arrays as columns into a 2-D array.
497+
vsplit : Split an array into multiple sub-arrays vertically (row-wise).
498+
unstack : Split an array into a tuple of sub-arrays along an axis.
499+
block : Assemble an nd-array from nested lists of blocks.
500+
split : Split array into a list of multiple sub-arrays of equal size.
501+
502+
Examples
503+
--------
504+
>>> arr = np.arange(24).reshape((2, 3, 4))
505+
>>> np.unstack(arr)
506+
(array([[ 0, 1, 2, 3],
507+
[ 4, 5, 6, 7],
508+
[ 8, 9, 10, 11]]),
509+
array([[12, 13, 14, 15],
510+
[16, 17, 18, 19],
511+
[20, 21, 22, 23]]))
512+
>>> np.unstack(arr, axis=1)
513+
(array([[ 0, 1, 2, 3],
514+
[12, 13, 14, 15]]),
515+
array([[ 4, 5, 6, 7],
516+
[16, 17, 18, 19]]),
517+
array([[ 8, 9, 10, 11],
518+
[20, 21, 22, 23]]))
519+
>>> arr2 = np.stack(np.unstack(arr, axis=1), axis=1)
520+
>>> arr2.shape
521+
(2, 3, 4)
522+
>>> np.all(arr == arr2)
523+
np.True_
524+
525+
"""
526+
x = asanyarray(x)
527+
528+
axis = normalize_axis_index(axis, x.ndim)
529+
slices = (slice(None),) * axis
530+
return tuple(x[slices + (i, ...)] for i in range(x.shape[axis]))
458531

459532
# Internal functions to eliminate the overhead of repeated dispatch in one of
460533
# the two possible paths inside np.block.
@@ -709,7 +782,7 @@ def block(arrays):
709782
second-last dimension (-2), and so on until the outermost list is reached.
710783
711784
Blocks can be of any dimension, but will not be broadcasted using
712-
the normal rules. Instead, leading axes of size 1 are inserted,
785+
the normal rules. Instead, leading axes of size 1 are inserted,
713786
to make ``block.ndim`` the same for all blocks. This is primarily useful
714787
for working with scalars, and means that code like ``np.block([v, 1])``
715788
is valid, where ``v.ndim == 1``.
@@ -755,6 +828,7 @@ def block(arrays):
755828
dstack : Stack arrays in sequence depth wise (along third axis).
756829
column_stack : Stack 1-D arrays as columns into a 2-D array.
757830
vsplit : Split an array into multiple sub-arrays vertically (row-wise).
831+
unstack : Split an array into a tuple of sub-arrays along an axis.
758832
759833
Notes
760834
-----

0 commit comments

Comments
 (0)