|
38 | 38 |
|
39 | 39 |
|
40 | 40 | import numpy |
41 | | -from dpctl.tensor._numpy_helper import normalize_axis_index |
| 41 | +from dpctl.tensor._numpy_helper import ( |
| 42 | + normalize_axis_index, |
| 43 | + normalize_axis_tuple, |
| 44 | +) |
42 | 45 |
|
43 | 46 | import dpnp |
44 | 47 |
|
45 | | -__all__ = ["apply_along_axis"] |
| 48 | +__all__ = ["apply_along_axis", "apply_over_axes"] |
46 | 49 |
|
47 | 50 |
|
48 | 51 | def apply_along_axis(func1d, axis, arr, *args, **kwargs): |
@@ -185,3 +188,83 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): |
185 | 188 | buff = dpnp.moveaxis(buff, -1, axis) |
186 | 189 |
|
187 | 190 | return buff |
| 191 | + |
| 192 | + |
| 193 | +def apply_over_axes(func, a, axes): |
| 194 | + """ |
| 195 | + Apply a function repeatedly over multiple axes. |
| 196 | +
|
| 197 | + `func` is called as ``res = func(a, axis)``, where `axis` is the first |
| 198 | + element of `axes`. The result `res` of the function call must have |
| 199 | + either the same dimensions as `a` or one less dimension. If `res` |
| 200 | + has one less dimension than `a`, a dimension is inserted before |
| 201 | + `axis`. The call to `func` is then repeated for each axis in `axes`, |
| 202 | + with `res` as the first argument. |
| 203 | +
|
| 204 | + For full documentation refer to :obj:`numpy.apply_over_axes`. |
| 205 | +
|
| 206 | + Parameters |
| 207 | + ---------- |
| 208 | + func : function |
| 209 | + This function must take two arguments, ``func(a, axis)``. |
| 210 | + a : {dpnp.ndarray, usm_ndarray} |
| 211 | + Input array. |
| 212 | + axes : {int, sequence of ints} |
| 213 | + Axes over which `func` is applied. |
| 214 | +
|
| 215 | + Returns |
| 216 | + ------- |
| 217 | + out : dpnp.ndarray |
| 218 | + The output array. The number of dimensions is the same as `a`, |
| 219 | + but the shape can be different. This depends on whether `func` |
| 220 | + changes the shape of its output with respect to its input. |
| 221 | +
|
| 222 | + See Also |
| 223 | + -------- |
| 224 | + :obj:`dpnp.apply_along_axis` : Apply a function to 1-D slices of an array |
| 225 | + along the given axis. |
| 226 | +
|
| 227 | + Examples |
| 228 | + -------- |
| 229 | + >>> import dpnp as np |
| 230 | + >>> a = np.arange(24).reshape(2, 3, 4) |
| 231 | + >>> a |
| 232 | + array([[[ 0, 1, 2, 3], |
| 233 | + [ 4, 5, 6, 7], |
| 234 | + [ 8, 9, 10, 11]], |
| 235 | + [[12, 13, 14, 15], |
| 236 | + [16, 17, 18, 19], |
| 237 | + [20, 21, 22, 23]]]) |
| 238 | +
|
| 239 | + Sum over axes 0 and 2. The result has same number of dimensions |
| 240 | + as the original array: |
| 241 | +
|
| 242 | + >>> np.apply_over_axes(np.sum, a, [0, 2]) |
| 243 | + array([[[ 60], |
| 244 | + [ 92], |
| 245 | + [124]]]) |
| 246 | +
|
| 247 | + Tuple axis arguments to ufuncs are equivalent: |
| 248 | +
|
| 249 | + >>> np.sum(a, axis=(0, 2), keepdims=True) |
| 250 | + array([[[ 60], |
| 251 | + [ 92], |
| 252 | + [124]]]) |
| 253 | +
|
| 254 | + """ |
| 255 | + |
| 256 | + dpnp.check_supported_arrays_type(a) |
| 257 | + if isinstance(axes, int): |
| 258 | + axes = (axes,) |
| 259 | + axes = normalize_axis_tuple(axes, a.ndim) |
| 260 | + |
| 261 | + for axis in axes: |
| 262 | + res = func(a, axis) |
| 263 | + if res.ndim != a.ndim: |
| 264 | + res = dpnp.expand_dims(res, axis) |
| 265 | + if res.ndim != a.ndim: |
| 266 | + raise ValueError( |
| 267 | + "function is not returning an array of the correct shape" |
| 268 | + ) |
| 269 | + a = res |
| 270 | + return res |
0 commit comments