Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 9 additions & 20 deletions dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,11 @@
import dpnp

# pylint: disable=no-name-in-module
from .dpnp_algo import (
dpnp_correlate,
)
from .dpnp_algo import dpnp_correlate
from .dpnp_array import dpnp_array
from .dpnp_utils import (
call_origin,
get_usm_allocations,
)
from .dpnp_utils import call_origin, get_usm_allocations
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
from .dpnp_utils.dpnp_utils_statistics import (
dpnp_cov,
)
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov

__all__ = [
"amax",
Expand Down Expand Up @@ -276,21 +269,17 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=False):
"""

dpnp.check_supported_arrays_type(a)
usm_type, exec_q = get_usm_allocations([a, weights])

if weights is None:
avg = dpnp.mean(a, axis=axis, keepdims=keepdims)
scl = dpnp.asanyarray(
avg.dtype.type(a.size / avg.size),
usm_type=a.usm_type,
sycl_queue=a.sycl_queue,
usm_type=usm_type,
sycl_queue=exec_q,
)
else:
if not isinstance(weights, (dpnp_array, dpt.usm_ndarray)):
wgt = dpnp.asanyarray(
weights, usm_type=a.usm_type, sycl_queue=a.sycl_queue
)
else:
get_usm_allocations([a, weights])
wgt = weights
wgt = dpnp.asanyarray(weights, usm_type=usm_type, sycl_queue=exec_q)

if not dpnp.issubdtype(a.dtype, dpnp.inexact):
default_dtype = dpnp.default_float_type(a.device)
Expand Down Expand Up @@ -556,7 +545,7 @@ def cov(

"""

if not isinstance(m, (dpnp_array, dpt.usm_ndarray)):
if not dpnp.is_supported_array_type(m):
pass
elif m.ndim > 2:
pass
Expand Down
Loading