Skip to content

Commit 6ae2d0c

Browse files
authored
Merge pull request numpy#26579 from asmeurer/unstack
ENH: Add unstack()
2 parents c46a513 + 457de03 commit 6ae2d0c

File tree

8 files changed

+138
-6
lines changed

8 files changed

+138
-6
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
New function `numpy.unstack`
2+
----------------------------
3+
4+
A new function ``np.unstack(array, axis=...)`` was added, which splits
5+
an array into a tuple of arrays along an axis. It serves as the inverse
6+
of `numpy.stack`.

doc/source/reference/routines.array-manipulation.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ Splitting arrays
8888
dsplit
8989
hsplit
9090
vsplit
91+
unstack
9192

9293
Tiling arrays
9394
=============

numpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@
165165
str_, subtract, sum, swapaxes, take, tan, tanh, tensordot,
166166
timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte,
167167
ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong,
168-
ulonglong, unsignedinteger, ushort, var, vdot, vecdot, void, vstack,
169-
where, zeros, zeros_like
168+
ulonglong, unsignedinteger, unstack, ushort, var, vdot, vecdot, void,
169+
vstack, where, zeros, zeros_like
170170
)
171171

172172
# NOTE: It's still under discussion whether these aliases

numpy/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ from numpy._core.shape_base import (
399399
hstack as hstack,
400400
stack as stack,
401401
vstack as vstack,
402+
unstack as unstack,
402403
)
403404

404405
from numpy.lib import (

numpy/_core/shape_base.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
__all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'block', 'hstack',
2-
'stack', 'vstack']
2+
'stack', 'unstack', 'vstack']
33

44
import functools
55
import itertools
@@ -11,7 +11,6 @@
1111
from .multiarray import array, asanyarray, normalize_axis_index
1212
from . import fromnumeric as _from_nx
1313

14-
1514
array_function_dispatch = functools.partial(
1615
overrides.array_function_dispatch, module='numpy')
1716

@@ -261,6 +260,7 @@ def vstack(tup, *, dtype=None, casting="same_kind"):
261260
dstack : Stack arrays in sequence depth wise (along third axis).
262261
column_stack : Stack 1-D arrays as columns into a 2-D array.
263262
vsplit : Split an array into multiple sub-arrays vertically (row-wise).
263+
unstack : Split an array into a tuple of sub-arrays along an axis.
264264
265265
Examples
266266
--------
@@ -331,8 +331,9 @@ def hstack(tup, *, dtype=None, casting="same_kind"):
331331
vstack : Stack arrays in sequence vertically (row wise).
332332
dstack : Stack arrays in sequence depth wise (along third axis).
333333
column_stack : Stack 1-D arrays as columns into a 2-D array.
334-
hsplit : Split an array into multiple sub-arrays
334+
hsplit : Split an array into multiple sub-arrays
335335
horizontally (column-wise).
336+
unstack : Split an array into a tuple of sub-arrays along an axis.
336337
337338
Examples
338339
--------
@@ -414,6 +415,7 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
414415
concatenate : Join a sequence of arrays along an existing axis.
415416
block : Assemble an nd-array from nested lists of blocks.
416417
split : Split array into a list of multiple sub-arrays of equal size.
418+
unstack : Split an array into a tuple of sub-arrays along an axis.
417419
418420
Examples
419421
--------
@@ -456,6 +458,76 @@ def stack(arrays, axis=0, out=None, *, dtype=None, casting="same_kind"):
456458
return _nx.concatenate(expanded_arrays, axis=axis, out=out,
457459
dtype=dtype, casting=casting)
458460

461+
def _unstack_dispatcher(x, /, *, axis=None):
462+
return (x,)
463+
464+
@array_function_dispatch(_unstack_dispatcher)
465+
def unstack(x, /, *, axis=0):
466+
"""
467+
Split an array into a sequence of arrays along the given axis.
468+
469+
The ``axis`` parameter specifies the dimension along which the array will
470+
be split. For example, if ``axis=0`` (the default) it will be the first
471+
dimension and if ``axis=-1`` it will be the last dimension.
472+
473+
The result is a tuple of arrays split along ``axis``.
474+
475+
.. versionadded:: 2.1.0
476+
477+
Parameters
478+
----------
479+
x : ndarray
480+
The array to be unstacked.
481+
axis : int, optional
482+
Axis along which the array will be split. Default: ``0``.
483+
484+
Returns
485+
-------
486+
unstacked : tuple of ndarrays
487+
The unstacked arrays.
488+
489+
See Also
490+
--------
491+
stack : Join a sequence of arrays along a new axis.
492+
concatenate : Join a sequence of arrays along an existing axis.
493+
block : Assemble an nd-array from nested lists of blocks.
494+
split : Split array into a list of multiple sub-arrays of equal size.
495+
496+
Notes
497+
-----
498+
``unstack`` serves as the reverse operation of :py:func:`stack`, i.e.,
499+
``stack(unstack(x, axis=axis), axis=axis) == x``.
500+
501+
This function is equivalent to ``tuple(np.moveaxis(x, axis, 0))``, since
502+
iterating on an array iterates along the first axis.
503+
504+
Examples
505+
--------
506+
>>> arr = np.arange(24).reshape((2, 3, 4))
507+
>>> np.unstack(arr)
508+
(array([[ 0, 1, 2, 3],
509+
[ 4, 5, 6, 7],
510+
[ 8, 9, 10, 11]]),
511+
array([[12, 13, 14, 15],
512+
[16, 17, 18, 19],
513+
[20, 21, 22, 23]]))
514+
>>> np.unstack(arr, axis=1)
515+
(array([[ 0, 1, 2, 3],
516+
[12, 13, 14, 15]]),
517+
array([[ 4, 5, 6, 7],
518+
[16, 17, 18, 19]]),
519+
array([[ 8, 9, 10, 11],
520+
[20, 21, 22, 23]]))
521+
>>> arr2 = np.stack(np.unstack(arr, axis=1), axis=1)
522+
>>> arr2.shape
523+
(2, 3, 4)
524+
>>> np.all(arr == arr2)
525+
np.True_
526+
527+
"""
528+
if x.ndim == 0:
529+
raise ValueError("Input array must be at least 1-d.")
530+
return tuple(_nx.moveaxis(x, axis, 0))
459531

460532
# Internal functions to eliminate the overhead of repeated dispatch in one of
461533
# the two possible paths inside np.block.
@@ -710,7 +782,7 @@ def block(arrays):
710782
second-last dimension (-2), and so on until the outermost list is reached.
711783
712784
Blocks can be of any dimension, but will not be broadcasted using
713-
the normal rules. Instead, leading axes of size 1 are inserted,
785+
the normal rules. Instead, leading axes of size 1 are inserted,
714786
to make ``block.ndim`` the same for all blocks. This is primarily useful
715787
for working with scalars, and means that code like ``np.block([v, 1])``
716788
is valid, where ``v.ndim == 1``.
@@ -756,6 +828,7 @@ def block(arrays):
756828
dstack : Stack arrays in sequence depth wise (along third axis).
757829
column_stack : Stack 1-D arrays as columns into a 2-D array.
758830
vsplit : Split an array into multiple sub-arrays vertically (row-wise).
831+
unstack : Split an array into a tuple of sub-arrays along an axis.
759832
760833
Notes
761834
-----

numpy/_core/shape_base.pyi

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,21 @@ def stack(
117117
casting: _CastingKind = ...
118118
) -> _ArrayType: ...
119119

120+
@overload
121+
def unstack(
122+
array: _ArrayLike[_SCT],
123+
/,
124+
*,
125+
axis: int = ...,
126+
) -> tuple[NDArray[_SCT], ...]: ...
127+
@overload
128+
def unstack(
129+
array: ArrayLike,
130+
/,
131+
*,
132+
axis: int = ...,
133+
) -> tuple[NDArray[Any], ...]: ...
134+
120135
@overload
121136
def block(arrays: _ArrayLike[_SCT]) -> NDArray[_SCT]: ...
122137
@overload

numpy/_core/tests/test_shape_base.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,39 @@ def test_stack():
490490
stack((a, b), dtype=np.int64, axis=1, casting="safe")
491491

492492

493+
def test_unstack():
494+
a = np.arange(24).reshape((2, 3, 4))
495+
496+
for stacks in [np.unstack(a),
497+
np.unstack(a, axis=0),
498+
np.unstack(a, axis=-3)]:
499+
assert isinstance(stacks, tuple)
500+
assert len(stacks) == 2
501+
assert_array_equal(stacks[0], a[0])
502+
assert_array_equal(stacks[1], a[1])
503+
504+
for stacks in [np.unstack(a, axis=1),
505+
np.unstack(a, axis=-2)]:
506+
assert isinstance(stacks, tuple)
507+
assert len(stacks) == 3
508+
assert_array_equal(stacks[0], a[:, 0])
509+
assert_array_equal(stacks[1], a[:, 1])
510+
assert_array_equal(stacks[2], a[:, 2])
511+
512+
for stacks in [np.unstack(a, axis=2),
513+
np.unstack(a, axis=-1)]:
514+
assert isinstance(stacks, tuple)
515+
assert len(stacks) == 4
516+
assert_array_equal(stacks[0], a[:, :, 0])
517+
assert_array_equal(stacks[1], a[:, :, 1])
518+
assert_array_equal(stacks[2], a[:, :, 2])
519+
assert_array_equal(stacks[3], a[:, :, 3])
520+
521+
assert_raises(ValueError, np.unstack, a, axis=3)
522+
assert_raises(ValueError, np.unstack, a, axis=-4)
523+
assert_raises(ValueError, np.unstack, np.array(0), axis=0)
524+
525+
493526
@pytest.mark.parametrize("axis", [0])
494527
@pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"])
495528
@pytest.mark.parametrize("casting",

numpy/typing/tests/data/reveal/shape_base.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,6 @@ assert_type(np.kron(AR_f8, AR_f8), npt.NDArray[np.floating[Any]])
5353

5454
assert_type(np.tile(AR_i8, 5), npt.NDArray[np.int64])
5555
assert_type(np.tile(AR_LIKE_f8, [2, 2]), npt.NDArray[Any])
56+
57+
assert_type(np.unstack(AR_i8, axis=0), tuple[npt.NDArray[np.int64], ...])
58+
assert_type(np.unstack(AR_LIKE_f8, axis=0), tuple[npt.NDArray[Any], ...])

0 commit comments

Comments
 (0)