@@ -109,6 +109,44 @@ def test_apply_along_axis_invalid_axis():
109109 xp .apply_along_axis (xp .sum , axis , a )
110110
111111
112+ class TestApplyOverAxes (unittest .TestCase ):
113+
114+ @testing .numpy_cupy_array_equal (type_check = has_support_aspect64 ())
115+ def test_simple (self , xp ):
116+ a = xp .arange (24 ).reshape (2 , 3 , 4 )
117+ aoa_a = xp .apply_over_axes (xp .sum , a , [0 , 2 ])
118+ return aoa_a
119+
120+ def test_apply_over_axis_invalid_0darr (self ):
121+ # cupy will not accept 0darr, but numpy does
122+ with pytest .raises (AxisError ):
123+ a = cupy .array (42 )
124+ cupy .apply_over_axes (cupy .sum , a , 0 )
125+ # test for numpy, it can run without error
126+ a = numpy .array (42 )
127+ numpy .apply_over_axes (numpy .sum , a , 0 )
128+
129+ @testing .numpy_cupy_allclose (type_check = has_support_aspect64 ())
130+ def test_apply_over_axis_shape_preserve_func (self , xp ):
131+ a = xp .arange (10 ).reshape (2 , 5 , 1 )
132+
133+ def normalize (arr , axis ):
134+ """shape-preserve operation, return {x_i/sum(x)}"""
135+ row_sums = arr .sum (axis = axis )
136+ return a / row_sums [:, xp .newaxis ]
137+
138+ aoa_a = xp .apply_over_axes (normalize , a , 1 )
139+ assert a .shape == aoa_a .shape
140+ return aoa_a
141+
142+ def test_apply_over_axis_invalid_axis (self ):
143+ for xp in [numpy , cupy ]:
144+ a = xp .ones ((8 , 4 ))
145+ axis = 3
146+ with pytest .raises (AxisError ):
147+ xp .apply_over_axes (xp .sum , a , axis )
148+
149+
112150class TestPutAlongAxis (unittest .TestCase ):
113151
114152 @testing .for_all_dtypes ()
0 commit comments