Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 29 additions & 35 deletions dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,11 @@
import dpnp

# pylint: disable=no-name-in-module
from .dpnp_algo import (
dpnp_correlate,
)
from .dpnp_algo import dpnp_correlate
from .dpnp_array import dpnp_array
from .dpnp_utils import (
call_origin,
get_usm_allocations,
)
from .dpnp_utils import call_origin, get_usm_allocations
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
from .dpnp_utils.dpnp_utils_statistics import (
dpnp_cov,
)
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov

__all__ = [
"amax",
Expand Down Expand Up @@ -276,60 +269,61 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=False):
"""

dpnp.check_supported_arrays_type(a)
usm_type, exec_q = get_usm_allocations([a, weights])

if weights is None:
avg = dpnp.mean(a, axis=axis, keepdims=keepdims)
scl = dpnp.asanyarray(
avg.dtype.type(a.size / avg.size),
usm_type=a.usm_type,
sycl_queue=a.sycl_queue,
usm_type=usm_type,
sycl_queue=exec_q,
)
else:
if not isinstance(weights, (dpnp_array, dpt.usm_ndarray)):
wgt = dpnp.asanyarray(
weights, usm_type=a.usm_type, sycl_queue=a.sycl_queue
if not dpnp.is_supported_array_type(weights):
weights = dpnp.asarray(
weights, usm_type=usm_type, sycl_queue=exec_q
)
else:
get_usm_allocations([a, weights])
wgt = weights

if not dpnp.issubdtype(a.dtype, dpnp.inexact):
a_dtype = a.dtype
if not dpnp.issubdtype(a_dtype, dpnp.inexact):
default_dtype = dpnp.default_float_type(a.device)
result_dtype = dpnp.result_type(a.dtype, wgt.dtype, default_dtype)
res_dtype = dpnp.result_type(a_dtype, weights.dtype, default_dtype)
else:
result_dtype = dpnp.result_type(a.dtype, wgt.dtype)
res_dtype = dpnp.result_type(a_dtype, weights.dtype)

# Sanity checks
if a.shape != wgt.shape:
wgt_shape = weights.shape
a_shape = a.shape
if a_shape != wgt_shape:
if axis is None:
raise TypeError(
"Axis must be specified when shapes of input array and "
"weights differ."
)
if wgt.ndim != 1:
if weights.ndim != 1:
raise TypeError(
"1D weights expected when shapes of input array and "
"weights differ."
)
if wgt.shape[0] != a.shape[axis]:
if wgt_shape[0] != a_shape[axis]:
raise ValueError(
"Length of weights not compatible with specified axis."
)

# setup wgt to broadcast along axis
wgt = dpnp.broadcast_to(wgt, (a.ndim - 1) * (1,) + wgt.shape)
wgt = wgt.swapaxes(-1, axis)
# setup weights to broadcast along axis
weights = dpnp.broadcast_to(
weights, (a.ndim - 1) * (1,) + wgt_shape
)
weights = weights.swapaxes(-1, axis)

scl = wgt.sum(axis=axis, dtype=result_dtype, keepdims=keepdims)
scl = weights.sum(axis=axis, dtype=res_dtype, keepdims=keepdims)
if dpnp.any(scl == 0.0):
raise ZeroDivisionError("Weights sum to zero, can't be normalized")

# result_datatype
avg = (
dpnp.multiply(a, wgt).sum(
axis=axis, dtype=result_dtype, keepdims=keepdims
)
/ scl
avg = dpnp.multiply(a, weights).sum(
axis=axis, dtype=res_dtype, keepdims=keepdims
)
avg /= scl

if returned:
if scl.shape != avg.shape:
Expand Down Expand Up @@ -556,7 +550,7 @@ def cov(

"""

if not isinstance(m, (dpnp_array, dpt.usm_ndarray)):
if not dpnp.is_supported_array_type(m):
pass
elif m.ndim > 2:
pass
Expand Down
Loading