diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index fb58d1e4b48d..68cd417031cb 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -97,6 +97,7 @@ "transpose", "trim_zeros", "unique", + "unstack", "vsplit", "vstack", ] @@ -1723,6 +1724,8 @@ def hstack(tup, *, dtype=None, casting="same_kind"): :obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks. :obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal size. + :obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along + an axis. Examples -------- @@ -2913,6 +2916,8 @@ def stack(arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind"): :obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks. :obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal size. + :obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along + an axis. Examples -------- @@ -3413,6 +3418,84 @@ def unique( return _unpack_tuple(result) +def unstack(x, /, *, axis=0): + """ + Split an array into a sequence of arrays along the given axis. + + The `axis` parameter specifies the dimension along which the array will + be split. For example, if ``axis=0`` (the default) it will be the first + dimension and if ``axis=-1`` it will be the last dimension. + + The result is a tuple of arrays split along `axis`. + + For full documentation refer to :obj:`numpy.unstack`. + + Parameters + ---------- + x : {dpnp.ndarray, usm_ndarray} + The array to be unstacked. + axis : int, optional + Axis along which the array will be split. + Default: ``0``. + + Returns + ------- + unstacked : tuple of dpnp.ndarray + The unstacked arrays. + + See Also + -------- + :obj:`dpnp.stack` : Join a sequence of arrays along a new axis. + :obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis. + :obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks. + :obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal + size. + + Notes + ----- + :obj:`dpnp.unstack` serves as the reverse operation of :obj:`dpnp.stack`, + i.e., ``dpnp.stack(dpnp.unstack(x, axis=axis), axis=axis) == x``. + + This function is equivalent to ``tuple(dpnp.moveaxis(x, axis, 0))``, since + iterating on an array iterates along the first axis. + + Examples + -------- + >>> import dpnp as np + >>> arr = np.arange(24).reshape((2, 3, 4)) + >>> np.unstack(arr) + (array([[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11]]), + array([[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]])) + + >>> np.unstack(arr, axis=1) + (array([[ 0, 1, 2, 3], + [12, 13, 14, 15]]), + array([[ 4, 5, 6, 7], + [16, 17, 18, 19]]), + array([[ 8, 9, 10, 11], + [20, 21, 22, 23]])) + + >>> arr2 = np.stack(np.unstack(arr, axis=1), axis=1) + >>> arr2.shape + (2, 3, 4) + >>> np.all(arr == arr2) + array(True) + + """ + + usm_x = dpnp.get_usm_ndarray(x) + + if usm_x.ndim == 0: + raise ValueError("Input array must be at least 1-d.") + + res = dpt.unstack(usm_x, axis=axis) + return tuple(dpnp_array._create_from_usm_ndarray(a) for a in res) + + def vsplit(ary, indices_or_sections): """ Split an array into multiple sub-arrays vertically (row-wise). @@ -3521,6 +3604,8 @@ def vstack(tup, *, dtype=None, casting="same_kind"): :obj:`dpnp.block` : Assemble an ndarray from nested lists of blocks. :obj:`dpnp.split` : Split array into a list of multiple sub-arrays of equal size. + :obj:`dpnp.unstack` : Split an array into a tuple of sub-arrays along + an axis. Examples -------- diff --git a/tests/test_arraymanipulation.py b/tests/test_arraymanipulation.py index 7eb210471fa5..0760988af961 100644 --- a/tests/test_arraymanipulation.py +++ b/tests/test_arraymanipulation.py @@ -866,6 +866,85 @@ def test_generator(self): dpnp.stack(map(lambda x: x, dpnp.ones((3, 2)))) +# numpy.unstack() is available since numpy >= 2.1 +@testing.with_requires("numpy>=2.1") +class TestUnstack: + def test_non_array_input(self): + with pytest.raises(TypeError): + dpnp.unstack(1) + + @pytest.mark.parametrize( + "input", [([1, 2, 3],), [dpnp.int32(1), dpnp.int32(2), dpnp.int32(3)]] + ) + def test_scalar_input(self, input): + with pytest.raises(TypeError): + dpnp.unstack(input) + + @pytest.mark.parametrize("dtype", get_all_dtypes()) + def test_0d_array_input(self, dtype): + np_a = numpy.array(1, dtype=dtype) + dp_a = dpnp.array(np_a, dtype=dtype) + + with pytest.raises(ValueError): + numpy.unstack(np_a) + with pytest.raises(ValueError): + dpnp.unstack(dp_a) + + @pytest.mark.parametrize("dtype", get_all_dtypes()) + def test_1d_array(self, dtype): + np_a = numpy.array([1, 2, 3], dtype=dtype) + dp_a = dpnp.array(np_a, dtype=dtype) + + np_res = numpy.unstack(np_a) + dp_res = dpnp.unstack(dp_a) + assert len(dp_res) == len(np_res) + for dp_arr, np_arr in zip(dp_res, np_res): + assert_array_equal(dp_arr.asnumpy(), np_arr) + + @pytest.mark.parametrize("dtype", get_all_dtypes()) + def test_2d_array(self, dtype): + np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + dp_a = dpnp.array(np_a, dtype=dtype) + + np_res = numpy.unstack(np_a, axis=0) + dp_res = dpnp.unstack(dp_a, axis=0) + assert len(dp_res) == len(np_res) + for dp_arr, np_arr in zip(dp_res, np_res): + assert_array_equal(dp_arr.asnumpy(), np_arr) + + @pytest.mark.parametrize("axis", [0, 1, -1]) + @pytest.mark.parametrize("dtype", get_all_dtypes()) + def test_2d_array_axis(self, axis, dtype): + np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + dp_a = dpnp.array(np_a, dtype=dtype) + + np_res = numpy.unstack(np_a, axis=axis) + dp_res = dpnp.unstack(dp_a, axis=axis) + assert len(dp_res) == len(np_res) + for dp_arr, np_arr in zip(dp_res, np_res): + assert_array_equal(dp_arr.asnumpy(), np_arr) + + @pytest.mark.parametrize("axis", [2, -3]) + @pytest.mark.parametrize("dtype", get_all_dtypes()) + def test_invalid_axis(self, axis, dtype): + np_a = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + dp_a = dpnp.array(np_a, dtype=dtype) + + with pytest.raises(AxisError): + numpy.unstack(np_a, axis=axis) + with pytest.raises(AxisError): + dpnp.unstack(dp_a, axis=axis) + + @pytest.mark.parametrize("dtype", get_all_dtypes()) + def test_empty_array(self, dtype): + np_a = numpy.array([], dtype=dtype) + dp_a = dpnp.array(np_a, dtype=dtype) + + np_res = numpy.unstack(np_a) + dp_res = dpnp.unstack(dp_a) + assert len(dp_res) == len(np_res) + + class TestVstack: def test_non_iterable(self): assert_raises(TypeError, dpnp.vstack, 1)