diff --git a/dpnp/tests/third_party/cupy/lib_tests/test_shape_base.py b/dpnp/tests/third_party/cupy/lib_tests/test_shape_base.py index 76da1cdb1de0..c241824fa81d 100644 --- a/dpnp/tests/third_party/cupy/lib_tests/test_shape_base.py +++ b/dpnp/tests/third_party/cupy/lib_tests/test_shape_base.py @@ -109,6 +109,44 @@ def test_apply_along_axis_invalid_axis(): xp.apply_along_axis(xp.sum, axis, a) +class TestApplyOverAxes(unittest.TestCase): + + @testing.numpy_cupy_array_equal(type_check=has_support_aspect64()) + def test_simple(self, xp): + a = xp.arange(24).reshape(2, 3, 4) + aoa_a = xp.apply_over_axes(xp.sum, a, [0, 2]) + return aoa_a + + def test_apply_over_axis_invalid_0darr(self): + # cupy will not accept 0darr, but numpy does + with pytest.raises(AxisError): + a = cupy.array(42) + cupy.apply_over_axes(cupy.sum, a, 0) + # test for numpy, it can run without error + a = numpy.array(42) + numpy.apply_over_axes(numpy.sum, a, 0) + + @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) + def test_apply_over_axis_shape_preserve_func(self, xp): + a = xp.arange(10).reshape(2, 5, 1) + + def normalize(arr, axis): + """shape-preserve operation, return {x_i/sum(x)}""" + row_sums = arr.sum(axis=axis) + return a / row_sums[:, xp.newaxis] + + aoa_a = xp.apply_over_axes(normalize, a, 1) + assert a.shape == aoa_a.shape + return aoa_a + + def test_apply_over_axis_invalid_axis(self): + for xp in [numpy, cupy]: + a = xp.ones((8, 4)) + axis = 3 + with pytest.raises(AxisError): + xp.apply_over_axes(xp.sum, a, axis) + + class TestPutAlongAxis(unittest.TestCase): @testing.for_all_dtypes()