Skip to content

Commit b839c40

Browse files
authored
simplify dpnp.average implementation (#2189)
* simplify dpnp.average implementation * address comments
1 parent ffd3829 commit b839c40

File tree

1 file changed

+29
-35
lines changed

1 file changed

+29
-35
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,11 @@
4747
import dpnp
4848

4949
# pylint: disable=no-name-in-module
50-
from .dpnp_algo import (
51-
dpnp_correlate,
52-
)
50+
from .dpnp_algo import dpnp_correlate
5351
from .dpnp_array import dpnp_array
54-
from .dpnp_utils import (
55-
call_origin,
56-
get_usm_allocations,
57-
)
52+
from .dpnp_utils import call_origin, get_usm_allocations
5853
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
59-
from .dpnp_utils.dpnp_utils_statistics import (
60-
dpnp_cov,
61-
)
54+
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov
6255

6356
__all__ = [
6457
"amax",
@@ -276,60 +269,61 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=False):
276269
"""
277270

278271
dpnp.check_supported_arrays_type(a)
272+
usm_type, exec_q = get_usm_allocations([a, weights])
273+
279274
if weights is None:
280275
avg = dpnp.mean(a, axis=axis, keepdims=keepdims)
281276
scl = dpnp.asanyarray(
282277
avg.dtype.type(a.size / avg.size),
283-
usm_type=a.usm_type,
284-
sycl_queue=a.sycl_queue,
278+
usm_type=usm_type,
279+
sycl_queue=exec_q,
285280
)
286281
else:
287-
if not isinstance(weights, (dpnp_array, dpt.usm_ndarray)):
288-
wgt = dpnp.asanyarray(
289-
weights, usm_type=a.usm_type, sycl_queue=a.sycl_queue
282+
if not dpnp.is_supported_array_type(weights):
283+
weights = dpnp.asarray(
284+
weights, usm_type=usm_type, sycl_queue=exec_q
290285
)
291-
else:
292-
get_usm_allocations([a, weights])
293-
wgt = weights
294286

295-
if not dpnp.issubdtype(a.dtype, dpnp.inexact):
287+
a_dtype = a.dtype
288+
if not dpnp.issubdtype(a_dtype, dpnp.inexact):
296289
default_dtype = dpnp.default_float_type(a.device)
297-
result_dtype = dpnp.result_type(a.dtype, wgt.dtype, default_dtype)
290+
res_dtype = dpnp.result_type(a_dtype, weights.dtype, default_dtype)
298291
else:
299-
result_dtype = dpnp.result_type(a.dtype, wgt.dtype)
292+
res_dtype = dpnp.result_type(a_dtype, weights.dtype)
300293

301294
# Sanity checks
302-
if a.shape != wgt.shape:
295+
wgt_shape = weights.shape
296+
a_shape = a.shape
297+
if a_shape != wgt_shape:
303298
if axis is None:
304299
raise TypeError(
305300
"Axis must be specified when shapes of input array and "
306301
"weights differ."
307302
)
308-
if wgt.ndim != 1:
303+
if weights.ndim != 1:
309304
raise TypeError(
310305
"1D weights expected when shapes of input array and "
311306
"weights differ."
312307
)
313-
if wgt.shape[0] != a.shape[axis]:
308+
if wgt_shape[0] != a_shape[axis]:
314309
raise ValueError(
315310
"Length of weights not compatible with specified axis."
316311
)
317312

318-
# setup wgt to broadcast along axis
319-
wgt = dpnp.broadcast_to(wgt, (a.ndim - 1) * (1,) + wgt.shape)
320-
wgt = wgt.swapaxes(-1, axis)
313+
# setup weights to broadcast along axis
314+
weights = dpnp.broadcast_to(
315+
weights, (a.ndim - 1) * (1,) + wgt_shape
316+
)
317+
weights = weights.swapaxes(-1, axis)
321318

322-
scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims)
319+
scl = weights.sum(axis=axis, dtype=res_dtype, keepdims=keepdims)
323320
if dpnp.any(scl == 0.0):
324321
raise ZeroDivisionError("Weights sum to zero, can't be normalized")
325322

326-
# result_datatype
327-
avg = (
328-
dpnp.multiply(a, wgt).sum(
329-
axis=axis, dtype=result_dtype, keepdims=keepdims
330-
)
331-
/ scl
323+
avg = dpnp.multiply(a, weights).sum(
324+
axis=axis, dtype=res_dtype, keepdims=keepdims
332325
)
326+
avg /= scl
333327

334328
if returned:
335329
if scl.shape != avg.shape:
@@ -556,7 +550,7 @@ def cov(
556550
557551
"""
558552

559-
if not isinstance(m, (dpnp_array, dpt.usm_ndarray)):
553+
if not dpnp.is_supported_array_type(m):
560554
pass
561555
elif m.ndim > 2:
562556
pass

0 commit comments

Comments
 (0)