@@ -109,6 +109,43 @@ def test_apply_along_axis_invalid_axis():
109109 xp .apply_along_axis (xp .sum , axis , a )
110110
111111
112+ class TestApplyOverAxes (unittest .TestCase ):
113+ @testing .numpy_cupy_array_equal (type_check = has_support_aspect64 ())
114+ def test_simple (self , xp ):
115+ a = xp .arange (24 ).reshape (2 , 3 , 4 )
116+ aoa_a = xp .apply_over_axes (xp .sum , a , [0 , 2 ])
117+ return aoa_a
118+
119+ def test_apply_over_axis_invalid_0darr (self ):
120+ # cupy will not accept 0darr, but numpy does
121+ with pytest .raises (AxisError ):
122+ a = cupy .array (42 )
123+ cupy .apply_over_axes (cupy .sum , a , 0 )
124+ # test for numpy, it can run without error
125+ a = numpy .array (42 )
126+ numpy .apply_over_axes (numpy .sum , a , 0 )
127+
128+ @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
129+ def test_apply_over_axis_shape_preserve_func (self , xp ):
130+ a = xp .arange (10 ).reshape (2 , 5 , 1 )
131+
132+ def normalize (arr , axis ):
133+ """shape-preserve operation, return {x_i/sum(x)}"""
134+ row_sums = arr .sum (axis = axis )
135+ return a / row_sums [:, xp .newaxis ]
136+
137+ aoa_a = xp .apply_over_axes (normalize , a , 1 )
138+ assert a .shape == aoa_a .shape
139+ return aoa_a
140+
141+ def test_apply_over_axis_invalid_axis (self ):
142+ for xp in [numpy , cupy ]:
143+ a = xp .ones ((8 , 4 ))
144+ axis = 3
145+ with pytest .raises (AxisError ):
146+ xp .apply_over_axes (xp .sum , a , axis )
147+
148+
112149class TestPutAlongAxis (unittest .TestCase ):
113150
114151 @testing .for_all_dtypes ()
0 commit comments