Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/reference/manipulation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ Joining arrays
dpnp.dstack
dpnp.column_stack
dpnp.row_stack
dpnp.unstack


Splitting arrays
Expand Down
85 changes: 85 additions & 0 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"transpose",
"trim_zeros",
"unique",
"unstack",
"vsplit",
"vstack",
]
Expand Down Expand Up @@ -1722,6 +1723,8 @@ def hstack(tup, *, dtype=None, casting="same_kind"):
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
size.
:obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along
an axis.

Examples
--------
Expand Down Expand Up @@ -2860,6 +2863,8 @@ def stack(arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind"):
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
size.
:obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along
an axis.

Examples
--------
Expand Down Expand Up @@ -3360,6 +3365,84 @@ def unique(
return _unpack_tuple(result)


def unstack(x, /, *, axis=0):
"""
Split an array into a sequence of arrays along the given axis.

The `axis` parameter specifies the dimension along which the array will
be split. For example, if `axis=0` (the default) it will be the first
dimension and if `axis=-1` it will be the last dimension.

The result is a tuple of arrays split along `axis`.

For full documentation refer to :obj:`numpy.unstack`.

Parameters
----------
x : {dpnp.ndarray, usm_ndarray}
The array to be unstacked.
axis : int, optional
Axis along which the array will be split.
Default: ``0``.

Returns
-------
unstacked : tuple of dpnp.ndarray
The unstacked arrays.

See Also
--------
:obj:`dpnp.stack` : Join a sequence of arrays along a new axis.
:obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis.
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
size.

Notes
-----
``unstack`` serves as the reverse operation of :obj:`dpnp.stack`, i.e.,
``dpnp.stack(dpnp.unstack(x, axis=axis), axis=axis) == x``.

This function is equivalent to ``tuple(dpnp.moveaxis(x, axis, 0))``, since
iterating on an array iterates along the first axis.

Examples
--------
>>> import dpnp as np
>>> arr = np.arange(24).reshape((2, 3, 4))
>>> np.unstack(arr)
(array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]),
array([[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]))

>>> np.unstack(arr, axis=1)
(array([[ 0, 1, 2, 3],
[12, 13, 14, 15]]),
array([[ 4, 5, 6, 7],
[16, 17, 18, 19]]),
array([[ 8, 9, 10, 11],
[20, 21, 22, 23]]))

>>> arr2 = np.stack(np.unstack(arr, axis=1), axis=1)
>>> arr2.shape
(2, 3, 4)
>>> np.all(arr == arr2)
array(True)

"""

usm_x = dpnp.get_usm_ndarray(x)

if usm_x.ndim == 0:
raise ValueError("Input array must be at least 1-d.")

res = dpt.unstack(usm_x, axis=axis)
return tuple(dpnp_array._create_from_usm_ndarray(a) for a in res)


def vsplit(ary, indices_or_sections):
"""
Split an array into multiple sub-arrays vertically (row-wise).
Expand Down Expand Up @@ -3468,6 +3551,8 @@ def vstack(tup, *, dtype=None, casting="same_kind"):
:obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks.
:obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal
size.
:obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along
an axis.

Examples
--------
Expand Down
79 changes: 79 additions & 0 deletions tests/test_arraymanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,85 @@ def test_generator(self):
dpnp.stack(map(lambda x: x, dpnp.ones((3, 2))))


# numpy.unstack() is available since numpy >= 2.1
@testing.with_requires("numpy>=2.1")
class TestUnstack:
def test_non_array_input(self):
with pytest.raises(TypeError):
dpnp.unstack(1)

@pytest.mark.parametrize(
"input", [([1, 2, 3],), [dpnp.int32(1), dpnp.int32(2), dpnp.int32(3)]]
)
def test_scalar_input(self, input):
with pytest.raises(TypeError):
dpnp.unstack(input)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_0d_array_input(self, dtype):
np_a = numpy.array(1, dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

with pytest.raises(ValueError):
numpy.unstack(np_a)
with pytest.raises(ValueError):
dpnp.unstack(dp_a)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_1d_array(self, dtype):
np_a = numpy.array([1, 2, 3], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

np_res = numpy.unstack(np_a)
dp_res = dpnp.unstack(dp_a)
assert len(dp_res) == len(np_res)
for dp_arr, np_arr in zip(dp_res, np_res):
assert_array_equal(dp_arr.asnumpy(), np_arr)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_2d_array(self, dtype):
np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

np_res = numpy.unstack(np_a, axis=0)
dp_res = dpnp.unstack(dp_a, axis=0)
assert len(dp_res) == len(np_res)
for dp_arr, np_arr in zip(dp_res, np_res):
assert_array_equal(dp_arr.asnumpy(), np_arr)

@pytest.mark.parametrize("axis", [0, 1, -1])
@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_2d_array_axis(self, axis, dtype):
np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

np_res = numpy.unstack(np_a, axis=axis)
dp_res = dpnp.unstack(dp_a, axis=axis)
assert len(dp_res) == len(np_res)
for dp_arr, np_arr in zip(dp_res, np_res):
assert_array_equal(dp_arr.asnumpy(), np_arr)

@pytest.mark.parametrize("axis", [2, -3])
@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_invalid_axis(self, axis, dtype):
np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

with pytest.raises(AxisError):
numpy.unstack(np_a, axis=axis)
with pytest.raises(AxisError):
dpnp.unstack(dp_a, axis=axis)

@pytest.mark.parametrize("dtype", get_all_dtypes())
def test_empty_array(self, dtype):
np_a = numpy.array([], dtype=dtype)
dp_a = dpnp.array(np_a, dtype=dtype)

np_res = numpy.unstack(np_a)
dp_res = dpnp.unstack(dp_a)
assert len(dp_res) == len(np_res)


class TestVstack:
def test_non_iterable(self):
assert_raises(TypeError, dpnp.vstack, 1)
Expand Down
Loading