diff --git a/dpnp/backend/include/dpnp_iface.hpp b/dpnp/backend/include/dpnp_iface.hpp index 5be448c446d5..454664fb3418 100644 --- a/dpnp/backend/include/dpnp_iface.hpp +++ b/dpnp/backend/include/dpnp_iface.hpp @@ -195,27 +195,6 @@ INP_DLLEXPORT void dpnp_partition_c(void *array, const shape_elem_type *shape, const size_t ndim); -/** - * @ingroup BACKEND_API - * @brief math library implementation of sort function - * - * @param [in] q_ref Reference to SYCL queue. - * @param [in] array Input array with data. - * @param [out] result Output array with indices. - * @param [in] size Number of elements in input arrays. - * @param [in] dep_event_vec_ref Reference to vector of SYCL events. - */ -template -INP_DLLEXPORT DPCTLSyclEventRef - dpnp_sort_c(DPCTLSyclQueueRef q_ref, - void *array, - void *result, - size_t size, - const DPCTLEventVectorRef dep_event_vec_ref); - -template -INP_DLLEXPORT void dpnp_sort_c(void *array, void *result, size_t size); - /** * @ingroup BACKEND_API * @brief correlate function @@ -318,38 +297,6 @@ INP_DLLEXPORT DPCTLSyclEventRef template INP_DLLEXPORT void dpnp_initval_c(void *result1, void *value, size_t size); -/** - * @ingroup BACKEND_API - * @brief math library implementation of median function - * - * @param [in] q_ref Reference to SYCL queue. - * @param [in] array Input array with data. - * @param [out] result Output array. - * @param [in] shape Shape of input array. - * @param [in] ndim Number of elements in shape. - * @param [in] axis Axis. - * @param [in] naxis Number of elements in axis. - * @param [in] dep_event_vec_ref Reference to vector of SYCL events. - */ -template -INP_DLLEXPORT DPCTLSyclEventRef - dpnp_median_c(DPCTLSyclQueueRef q_ref, - void *array, - void *result, - const shape_elem_type *shape, - size_t ndim, - const shape_elem_type *axis, - size_t naxis, - const DPCTLEventVectorRef dep_event_vec_ref); - -template -INP_DLLEXPORT void dpnp_median_c(void *array, - void *result, - const shape_elem_type *shape, - size_t ndim, - const shape_elem_type *axis, - size_t naxis); - #define MACRO_1ARG_1TYPE_OP(__name__, __operation1__, __operation2__) \ template \ INP_DLLEXPORT DPCTLSyclEventRef __name__( \ diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index f332ce29081d..e538d1317569 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -74,9 +74,6 @@ enum class DPNPFuncName : size_t */ DPNP_FN_INITVAL_EXT, /**< Used in numpy ones, ones_like, zeros, zeros_like impls */ - DPNP_FN_MEDIAN, /**< Used in numpy.median() impl */ - DPNP_FN_MEDIAN_EXT, /**< Used in numpy.median() impl, requires extra - parameters */ DPNP_FN_MODF, /**< Used in numpy.modf() impl */ DPNP_FN_MODF_EXT, /**< Used in numpy.modf() impl, requires extra parameters */ @@ -203,7 +200,6 @@ enum class DPNPFuncName : size_t DPNP_FN_RNG_ZIPF, /**< Used in numpy.random.zipf() impl */ DPNP_FN_RNG_ZIPF_EXT, /**< Used in numpy.random.zipf() impl, requires extra parameters */ - DPNP_FN_SORT, /**< Used in numpy.sort() impl */ DPNP_FN_ZEROS, /**< Used in numpy.zeros() impl */ DPNP_FN_ZEROS_LIKE, /**< Used in numpy.zeros_like() impl */ DPNP_FN_LAST, /**< The latest element of the enumeration */ diff --git a/dpnp/backend/kernels/dpnp_krnl_sorting.cpp b/dpnp/backend/kernels/dpnp_krnl_sorting.cpp index 73a15df0a5c8..80071656d856 100644 --- a/dpnp/backend/kernels/dpnp_krnl_sorting.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_sorting.cpp @@ -30,15 +30,6 @@ #include "queue_sycl.hpp" #include -template -struct _sort_less -{ - inline bool operator()(const _DataType &val1, const _DataType &val2) - { - return (val1 < val2); - } -}; - template class dpnp_partition_c_kernel; @@ -199,55 +190,6 @@ DPCTLSyclEventRef (*dpnp_partition_ext_c)(DPCTLSyclQueueRef, const DPCTLEventVectorRef) = dpnp_partition_c<_DataType>; -template -class dpnp_sort_c_kernel; - -template -DPCTLSyclEventRef dpnp_sort_c(DPCTLSyclQueueRef q_ref, - void *array1_in, - void *result1, - size_t size, - const DPCTLEventVectorRef dep_event_vec_ref) -{ - // avoid warning unused variable - (void)dep_event_vec_ref; - - DPCTLSyclEventRef event_ref = nullptr; - sycl::queue q = *(reinterpret_cast(q_ref)); - - DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size, true); - DPNPC_ptr_adapter<_DataType> result1_ptr(q_ref, result1, size, true, true); - _DataType *array_1 = input1_ptr.get_ptr(); - _DataType *result = result1_ptr.get_ptr(); - - std::copy(array_1, array_1 + size, result); - - auto policy = oneapi::dpl::execution::make_device_policy< - class dpnp_sort_c_kernel<_DataType>>(q); - - // fails without explicitly specifying of comparator or with std::less - // during kernels compilation affects other kernels - std::sort(policy, result, result + size, _sort_less<_DataType>()); - - policy.queue().wait(); - - return event_ref; -} - -template -void dpnp_sort_c(void *array1_in, void *result1, size_t size) -{ - DPCTLSyclQueueRef q_ref = reinterpret_cast(&DPNP_QUEUE); - DPCTLEventVectorRef dep_event_vec_ref = nullptr; - DPCTLSyclEventRef event_ref = dpnp_sort_c<_DataType>( - q_ref, array1_in, result1, size, dep_event_vec_ref); - DPCTLEvent_WaitAndThrow(event_ref); - DPCTLEvent_Delete(event_ref); -} - -template -void (*dpnp_sort_default_c)(void *, void *, size_t) = dpnp_sort_c<_DataType>; - void func_map_init_sorting(func_map_t &fmap) { fmap[DPNPFuncName::DPNP_FN_PARTITION][eft_INT][eft_INT] = { @@ -274,14 +216,5 @@ void func_map_init_sorting(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_PARTITION_EXT][eft_C128][eft_C128] = { eft_C128, (void *)dpnp_partition_ext_c>}; - fmap[DPNPFuncName::DPNP_FN_SORT][eft_INT][eft_INT] = { - eft_INT, (void *)dpnp_sort_default_c}; - fmap[DPNPFuncName::DPNP_FN_SORT][eft_LNG][eft_LNG] = { - eft_LNG, (void *)dpnp_sort_default_c}; - fmap[DPNPFuncName::DPNP_FN_SORT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_sort_default_c}; - fmap[DPNPFuncName::DPNP_FN_SORT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_sort_default_c}; - return; } diff --git a/dpnp/backend/kernels/dpnp_krnl_statistics.cpp b/dpnp/backend/kernels/dpnp_krnl_statistics.cpp index 66a4881d7f2a..911737177532 100644 --- a/dpnp/backend/kernels/dpnp_krnl_statistics.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_statistics.cpp @@ -128,88 +128,6 @@ DPCTLSyclEventRef (*dpnp_correlate_ext_c)(DPCTLSyclQueueRef, const DPCTLEventVectorRef) = dpnp_correlate_c<_DataType_output, _DataType_input1, _DataType_input2>; -template -DPCTLSyclEventRef dpnp_median_c(DPCTLSyclQueueRef q_ref, - void *array1_in, - void *result1, - const shape_elem_type *shape, - size_t ndim, - const shape_elem_type *axis, - size_t naxis, - const DPCTLEventVectorRef dep_event_vec_ref) -{ - // avoid warning unused variable - (void)dep_event_vec_ref; - - __attribute__((unused)) void *tmp = (void *)(axis + naxis); - - DPCTLSyclEventRef event_ref = nullptr; - - const size_t size = std::accumulate(shape, shape + ndim, 1, - std::multiplies()); - if (!size) { - return event_ref; - } - - sycl::queue q = *(reinterpret_cast(q_ref)); - - DPNPC_ptr_adapter<_ResultType> result_ptr(q_ref, result1, 1, true, true); - _ResultType *result = result_ptr.get_ptr(); - - _DataType *sorted = reinterpret_cast<_DataType *>( - sycl::malloc_shared(size * sizeof(_DataType), q)); - - dpnp_sort_c<_DataType>(array1_in, sorted, size); - - if (size % 2 == 0) { - result[0] = - static_cast<_ResultType>(sorted[size / 2] + sorted[size / 2 - 1]) / - 2; - } - else { - result[0] = sorted[(size - 1) / 2]; - } - - sycl::free(sorted, q); - - return event_ref; -} - -template -void dpnp_median_c(void *array1_in, - void *result1, - const shape_elem_type *shape, - size_t ndim, - const shape_elem_type *axis, - size_t naxis) -{ - DPCTLSyclQueueRef q_ref = reinterpret_cast(&DPNP_QUEUE); - DPCTLEventVectorRef dep_event_vec_ref = nullptr; - DPCTLSyclEventRef event_ref = dpnp_median_c<_DataType, _ResultType>( - q_ref, array1_in, result1, shape, ndim, axis, naxis, dep_event_vec_ref); - DPCTLEvent_WaitAndThrow(event_ref); - DPCTLEvent_Delete(event_ref); -} - -template -void (*dpnp_median_default_c)(void *, - void *, - const shape_elem_type *, - size_t, - const shape_elem_type *, - size_t) = dpnp_median_c<_DataType, _ResultType>; - -template -DPCTLSyclEventRef (*dpnp_median_ext_c)(DPCTLSyclQueueRef, - void *, - void *, - const shape_elem_type *, - size_t, - const shape_elem_type *, - size_t, - const DPCTLEventVectorRef) = - dpnp_median_c<_DataType, _ResultType>; - void func_map_init_statistics(func_map_t &fmap) { fmap[DPNPFuncName::DPNP_FN_CORRELATE][eft_INT][eft_INT] = { @@ -278,35 +196,5 @@ void func_map_init_statistics(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_CORRELATE_EXT][eft_DBL][eft_DBL] = { eft_DBL, (void *)dpnp_correlate_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MEDIAN][eft_INT][eft_INT] = { - eft_DBL, (void *)dpnp_median_default_c}; - fmap[DPNPFuncName::DPNP_FN_MEDIAN][eft_LNG][eft_LNG] = { - eft_DBL, (void *)dpnp_median_default_c}; - fmap[DPNPFuncName::DPNP_FN_MEDIAN][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_median_default_c}; - fmap[DPNPFuncName::DPNP_FN_MEDIAN][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_median_default_c}; - - fmap[DPNPFuncName::DPNP_FN_MEDIAN_EXT][eft_INT][eft_INT] = { - get_default_floating_type(), - (void *)dpnp_median_ext_c< - int32_t, func_type_map_t::find_type>, - get_default_floating_type(), - (void *)dpnp_median_ext_c< - int32_t, func_type_map_t::find_type< - get_default_floating_type()>>}; - fmap[DPNPFuncName::DPNP_FN_MEDIAN_EXT][eft_LNG][eft_LNG] = { - get_default_floating_type(), - (void *)dpnp_median_ext_c< - int64_t, func_type_map_t::find_type>, - get_default_floating_type(), - (void *)dpnp_median_ext_c< - int64_t, func_type_map_t::find_type< - get_default_floating_type()>>}; - fmap[DPNPFuncName::DPNP_FN_MEDIAN_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_median_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MEDIAN_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_median_ext_c}; - return; } diff --git a/dpnp/backend/src/dpnp_fptr.hpp b/dpnp/backend/src/dpnp_fptr.hpp index 6413224a2139..1547bf2e5785 100644 --- a/dpnp/backend/src/dpnp_fptr.hpp +++ b/dpnp/backend/src/dpnp_fptr.hpp @@ -260,17 +260,6 @@ class dpnp_less_comp } }; -/** - * A template function that determines the default floating-point type - * based on the value of the template parameter has_fp64. - */ -template -static constexpr DPNPFuncType get_default_floating_type() -{ - return has_fp64::value ? DPNPFuncType::DPNP_FT_DOUBLE - : DPNPFuncType::DPNP_FT_FLOAT; -} - /** * FPTR interface initialization functions */ diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 1667e9d413a2..15b65135fe24 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -36,7 +36,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_CHOOSE_EXT DPNP_FN_CORRELATE_EXT DPNP_FN_ERF_EXT - DPNP_FN_MEDIAN_EXT DPNP_FN_MODF_EXT DPNP_FN_PARTITION_EXT DPNP_FN_RNG_BETA_EXT diff --git a/dpnp/dpnp_algo/dpnp_algo_statistics.pxi b/dpnp/dpnp_algo/dpnp_algo_statistics.pxi index bd54f2091c1f..1a788573ad28 100644 --- a/dpnp/dpnp_algo/dpnp_algo_statistics.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_statistics.pxi @@ -37,17 +37,9 @@ and the rest of the library __all__ += [ "dpnp_correlate", - "dpnp_median", ] -# C function pointer to the C library template functions -ctypedef c_dpctl.DPCTLSyclEventRef(*custom_statistic_1in_1out_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef, - void *, void * , shape_elem_type * , size_t, - shape_elem_type * , size_t, - const c_dpctl.DPCTLEventVectorRef) - - cpdef utils.dpnp_descriptor dpnp_correlate(utils.dpnp_descriptor x1, utils.dpnp_descriptor x2): cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype) cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(x2.dtype) @@ -90,47 +82,3 @@ cpdef utils.dpnp_descriptor dpnp_correlate(utils.dpnp_descriptor x1, utils.dpnp_ c_dpctl.DPCTLEvent_Delete(event_ref) return result - - -cpdef utils.dpnp_descriptor dpnp_median(utils.dpnp_descriptor array1): - cdef shape_type_c x1_shape = array1.shape - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype) - - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MEDIAN_EXT, param1_type, param1_type) - - array1_obj = array1.get_array() - - cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data, - array1_obj.sycl_device.has_aspect_fp64) - cdef DPNPFuncType return_type = ret_type_and_func[0] - cdef custom_statistic_1in_1out_func_ptr_t func = < custom_statistic_1in_1out_func_ptr_t > ret_type_and_func[1] - - cdef utils.dpnp_descriptor result = utils.create_output_descriptor((1,), - return_type, - None, - device=array1_obj.sycl_device, - usm_type=array1_obj.usm_type, - sycl_queue=array1_obj.sycl_queue) - - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - # stub for interface support - cdef shape_type_c axis - cdef Py_ssize_t axis_size = 0 - - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - array1.get_data(), - result.get_data(), - x1_shape.data(), - array1.ndim, - axis.data(), - axis_size, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result diff --git a/dpnp/dpnp_iface.py b/dpnp/dpnp_iface.py index 041ec608aa19..7d1a76448e56 100644 --- a/dpnp/dpnp_iface.py +++ b/dpnp/dpnp_iface.py @@ -61,7 +61,6 @@ "as_usm_ndarray", "check_limitations", "check_supported_arrays_type", - "convert_single_elem_array_to_scalar", "default_float_type", "from_dlpack", "get_dpnp_descriptor", @@ -407,15 +406,6 @@ def check_supported_arrays_type(*arrays, scalar_type=False, all_scalars=False): return True -def convert_single_elem_array_to_scalar(obj, keepdims=False): - """Convert array with single element to scalar.""" - - if (obj.ndim > 0) and (obj.size == 1) and (keepdims is False): - return obj.dtype.type(obj[0]) - - return obj - - def default_float_type(device=None, sycl_queue=None): """ Return a floating type used by default in DPNP depending on device diff --git a/dpnp/dpnp_iface_sorting.py b/dpnp/dpnp_iface_sorting.py index c007def058b3..01fc637db717 100644 --- a/dpnp/dpnp_iface_sorting.py +++ b/dpnp/dpnp_iface_sorting.py @@ -57,7 +57,9 @@ __all__ = ["argsort", "partition", "sort", "sort_complex"] -def _wrap_sort_argsort(a, _sorting_fn, axis=-1, kind=None, order=None): +def _wrap_sort_argsort( + a, _sorting_fn, axis=-1, kind=None, order=None, stable=True +): """Wrap a sorting call from dpctl.tensor interface.""" if order is not None: @@ -75,11 +77,11 @@ def _wrap_sort_argsort(a, _sorting_fn, axis=-1, kind=None, order=None): axis = -1 axis = normalize_axis_index(axis, ndim=usm_a.ndim) - usm_res = _sorting_fn(usm_a, axis=axis) + usm_res = _sorting_fn(usm_a, axis=axis, stable=stable) return dpnp_array._create_from_usm_ndarray(usm_res) -def argsort(a, axis=-1, kind=None, order=None): +def argsort(a, axis=-1, kind=None, order=None, *, stable=True): """ Returns the indices that would sort an array. @@ -89,12 +91,18 @@ def argsort(a, axis=-1, kind=None, order=None): ---------- a : {dpnp.ndarray, usm_ndarray} Array to be sorted. - axis : int or None, optional + axis : {None, int}, optional Axis along which to sort. If ``None``, the array is flattened before - sorting. The default is -1, which sorts along the last axis. + sorting. The default is ``-1``, which sorts along the last axis. kind : {None, "stable"}, optional - Default is ``None``, which is equivalent to `"stable"`. - Unlike NumPy, no other option is accepted here. + Sorting algorithm. Default is ``None``, which is equivalent to + ``"stable"``. Unlike NumPy, no other option is accepted here. + stable : {None, bool}, optional + Sort stability. If ``True``, the returned array will maintain + the relative order of ``a`` values which compare as equal. + The same behavior applies when set to ``False`` or ``None``. + Internally, this option selects ``kind="stable"``. + Default: ``None``. Returns ------- @@ -107,14 +115,13 @@ def argsort(a, axis=-1, kind=None, order=None): Notes ----- - For zero-dimensional arrays, if `axis=None`, output is a one-dimensional + For zero-dimensional arrays, if ``axis=None``, output is a one-dimensional array with a single zero element. Otherwise, an ``AxisError`` is raised. Limitations ----------- Parameters `order` is only supported with its default value. - Parameters `kind` can only be ``None`` or ``"stable"`` which - are equivalent. + Parameter `kind` can only be ``None`` or ``"stable"`` which are equivalent. Otherwise ``NotImplementedError`` exception will be raised. See Also @@ -156,7 +163,9 @@ def argsort(a, axis=-1, kind=None, order=None): """ - return _wrap_sort_argsort(a, dpt.argsort, axis=axis, kind=kind, order=order) + return _wrap_sort_argsort( + a, dpt.argsort, axis=axis, kind=kind, order=order, stable=stable + ) def partition(x1, kth, axis=-1, kind="introselect", order=None): @@ -194,7 +203,7 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None): return call_origin(numpy.partition, x1, kth, axis, kind, order) -def sort(a, axis=-1, kind=None, order=None): +def sort(a, axis=-1, kind=None, order=None, *, stable=True): """ Return a sorted copy of an array. @@ -204,12 +213,18 @@ def sort(a, axis=-1, kind=None, order=None): ---------- a : {dpnp.ndarray, usm_ndarray} Array to be sorted. - axis : int or None, optional + axis : {None, int}, optional Axis along which to sort. If ``None``, the array is flattened before - sorting. The default is -1, which sorts along the last axis. + sorting. The default is ``-1``, which sorts along the last axis. kind : {None, "stable"}, optional - Default is ``None``, which is equivalent to `"stable"`. - Unlike in NumPy any other options are not accepted here. + Sorting algorithm. Default is ``None``, which is equivalent to + ``"stable"``. Unlike NumPy, no other option is accepted here. + stable : {None, bool}, optional + Sort stability. If ``True``, the returned array will maintain + the relative order of ``a`` values which compare as equal. + The same behavior applies when set to ``False`` or ``None``. + Internally, this option selects ``kind="stable"``. + Default: ``None``. Returns ------- @@ -218,14 +233,13 @@ def sort(a, axis=-1, kind=None, order=None): Notes ----- - For zero-dimensional arrays, if `axis=None`, output is the input array + For zero-dimensional arrays, if ``axis=None``, output is the input array returned as a one-dimensional array. Otherwise, an ``AxisError`` is raised. Limitations ----------- Parameters `order` is only supported with its default value. - Parameters `kind` can only be ``None`` or ``"stable"`` which - are equivalent. + Parameter `kind` can only be ``None`` or ``"stable"`` which are equivalent. Otherwise ``NotImplementedError`` exception will be raised. See Also @@ -251,7 +265,9 @@ def sort(a, axis=-1, kind=None, order=None): """ - return _wrap_sort_argsort(a, dpt.sort, axis=axis, kind=kind, order=order) + return _wrap_sort_argsort( + a, dpt.sort, axis=axis, kind=kind, order=order, stable=stable + ) def sort_complex(a): diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index d663d9d18364..bc7d323ee16f 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -37,17 +37,18 @@ """ - import dpctl.tensor as dpt import numpy -from dpctl.tensor._numpy_helper import normalize_axis_index +from dpctl.tensor._numpy_helper import ( + normalize_axis_index, + normalize_axis_tuple, +) import dpnp # pylint: disable=no-name-in-module from .dpnp_algo import ( dpnp_correlate, - dpnp_median, ) from .dpnp_array import dpnp_array from .dpnp_utils import ( @@ -119,6 +120,22 @@ def _count_reduce_items(arr, axis, where=True): return items +def _flatten_array_along_axes(arr, axes_to_flatten): + """Flatten an array along a specific set of axes.""" + + axes_to_keep = ( + axis for axis in range(arr.ndim) if axis not in axes_to_flatten + ) + + # Move the axes_to_flatten to the front + arr_moved = dpnp.moveaxis(arr, axes_to_flatten, range(len(axes_to_flatten))) + + new_shape = (-1,) + tuple(arr.shape[axis] for axis in axes_to_keep) + flattened_arr = arr_moved.reshape(new_shape) + + return flattened_arr + + def _get_comparison_res_dt(a, _dtype, _out): """Get a data type used by dpctl for result array in comparison function.""" @@ -569,13 +586,15 @@ def mean(a, /, axis=None, dtype=None, out=None, keepdims=False, *, where=True): out : {None, dpnp.ndarray, usm_ndarray}, optional Alternative output array in which to place the result. It must have the same shape as the expected output but the type (of the calculated - values) will be cast if necessary. Default: ``None``. + values) will be cast if necessary. + Default: ``None``. keepdims : {None, bool}, optional If ``True``, the reduced axes (dimensions) are included in the result as singleton dimensions, so that the returned array remains compatible with the input array according to Array Broadcasting rules. Otherwise, if ``False``, the reduced axes are not included in - the returned array. Default: ``False``. + the returned array. + Default: ``False``. Returns ------- @@ -621,24 +640,54 @@ def mean(a, /, axis=None, dtype=None, out=None, keepdims=False, *, where=True): if dtype is not None: usm_res = dpt.astype(usm_res, dtype) - return dpnp.get_result_array(usm_res, out, casting="same_kind") + return dpnp.get_result_array(usm_res, out, casting="unsafe") -def median(x1, axis=None, out=None, overwrite_input=False, keepdims=False): +def median(a, axis=None, out=None, overwrite_input=False, keepdims=False): """ Compute the median along the specified axis. For full documentation refer to :obj:`numpy.median`. - Limitations - ----------- - Input array is supported as :obj:`dpnp.ndarray`. - Parameter `axis` is supported only with default value ``None``. - Parameter `out` is supported only with default value ``None``. - Parameter `overwrite_input` is supported only with default value ``False``. - Parameter `keepdims` is supported only with default value ``False``. - Otherwise the function will be executed sequentially on CPU. - Input array data types are limited by supported DPNP :ref:`Data types`. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array. + axis : {None, int, tuple or list of ints}, optional + Axis or axes along which the medians are computed. The default, + ``axis=None``, will compute the median along a flattened version of + the array. If a sequence of axes, the array is first flattened along + the given axes, then the median is computed along the resulting + flattened axis. + Default: ``None``. + out : {None, dpnp.ndarray, usm_ndarray}, optional + Alternative output array in which to place the result. It must have + the same shape as the expected output but the type (of the calculated + values) will be cast if necessary. + Default: ``None``. + overwrite_input : bool, optional + If ``True``, then allow use of memory of input array `a` for + calculations. The input array will be modified by the call to + :obj:`dpnp.median`. This will save memory when you do not need to + preserve the contents of the input array. Treat the input as undefined, + but it will probably be fully or partially sorted. + Default: ``False``. + keepdims : {None, bool}, optional + If ``True``, the reduced axes (dimensions) are included in the result + as singleton dimensions, so that the returned array remains + compatible with the input array according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included in + the returned array. + Default: ``False``. + + Returns + ------- + dpnp.median : dpnp.ndarray + A new array holding the result. If `a` has a floating-point data type, + the returned array will have the same data type as `a`. If `a` has a + boolean or integral data type, the returned array will have the + default floating point data type for the device where input array `a` + is allocated. See Also -------- @@ -646,32 +695,105 @@ def median(x1, axis=None, out=None, overwrite_input=False, keepdims=False): :obj:`dpnp.percentile` : Compute the q-th percentile of the data along the specified axis. + Notes + ----- + Given a vector ``V`` of length ``N``, the median of ``V`` is the + middle value of a sorted copy of ``V``, ``V_sorted`` - i.e., + ``V_sorted[(N-1)/2]``, when ``N`` is odd, and the average of the + two middle values of ``V_sorted`` when ``N`` is even. + Examples -------- >>> import dpnp as np >>> a = np.array([[10, 7, 4], [3, 2, 1]]) + >>> a + array([[10, 7, 4], + [ 3, 2, 1]]) >>> np.median(a) - 3.5 + array(3.5) + + >>> np.median(a, axis=0) + array([6.5, 4.5, 2.5]) + >>> np.median(a, axis=1) + array([7., 2.]) + >>> np.median(a, axis=(0, 1)) + array(3.5) + + >>> m = np.median(a, axis=0) + >>> out = np.zeros_like(m) + >>> np.median(a, axis=0, out=m) + array([6.5, 4.5, 2.5]) + >>> m + array([6.5, 4.5, 2.5]) + + >>> b = a.copy() + >>> np.median(b, axis=1, overwrite_input=True) + array([7., 2.]) + >>> assert not np.all(a==b) + >>> b = a.copy() + >>> np.median(b, axis=None, overwrite_input=True) + array(3.5) + >>> assert not np.all(a==b) """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if x1_desc: - if axis is not None: - pass - elif out is not None: - pass - elif overwrite_input: - pass - elif keepdims: - pass + dpnp.check_supported_arrays_type(a) + a_ndim = a.ndim + a_shape = a.shape + _axis = range(a_ndim) if axis is None else axis + _axis = normalize_axis_tuple(_axis, a_ndim) + + if isinstance(axis, (tuple, list)): + if len(axis) == 1: + axis = axis[0] else: - result_obj = dpnp_median(x1_desc).get_pyobj() - result = dpnp.convert_single_elem_array_to_scalar(result_obj) - - return result - - return call_origin(numpy.median, x1, axis, out, overwrite_input, keepdims) + # Need to flatten if `axis` is a sequence of axes since `dpnp.sort` + # only accepts integer `axis` + # Note that the output of _flatten_array_along_axes is not + # necessarily a view of the input since `reshape` is used there. + # If this is the case, using overwrite_input is meaningless + a = _flatten_array_along_axes(a, _axis) + axis = 0 + + if overwrite_input: + if axis is None: + a_sorted = dpnp.ravel(a) + a_sorted.sort() + else: + if isinstance(a, dpt.usm_ndarray): + # dpnp.ndarray.sort only works with dpnp_array + a = dpnp_array._create_from_usm_ndarray(a) + a.sort(axis=axis) + a_sorted = a + else: + a_sorted = dpnp.sort(a, axis=axis) + + if axis is None: + axis = 0 + indexer = [slice(None)] * a_sorted.ndim + index, remainder = divmod(a_sorted.shape[axis], 2) + if remainder == 1: + # index with slice to allow mean (below) to work + indexer[axis] = slice(index, index + 1) + else: + indexer[axis] = slice(index - 1, index + 1) + + # Use `mean` in odd and even case to coerce data type and use `out` array + res = dpnp.mean(a_sorted[tuple(indexer)], axis=axis, out=out) + nan_mask = dpnp.isnan(a_sorted).any(axis=axis) + if nan_mask.any(): + res[nan_mask] = dpnp.nan + + if keepdims: + # We can't use dpnp.mean(..., keepdims) and dpnp.any(..., keepdims) + # above because of the reshape hack might have been used in + # `_flatten_array_along_axes` to handle cases when axis is a tuple. + res_shape = list(a_shape) + for i in _axis: + res_shape[i] = 1 + res = res.reshape(tuple(res_shape)) + + return res def min(a, axis=None, out=None, keepdims=False, initial=None, where=True): diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index d615f827c28d..e511f438d3e5 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -2467,7 +2467,17 @@ def test_out(self, func_params, dtype): result = getattr(dpnp, func_name)(dp_array, out=dp_out) assert result is dp_out - check_type = True if dpnp.issubdtype(dtype, dpnp.floating) else False + # numpy.ceil, numpy.floor, numpy.trunc always return float dtype for + # NumPy < 2.0.0 while output has the dtype of input for NumPy >= 2.0.0 + # (dpnp follows the latter behavior except for boolean dtype where it + # returns int8) + if ( + numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0" + or dtype == numpy.bool + ): + check_type = False + else: + check_type = True assert_dtype_allclose(result, expected, check_type=check_type) @pytest.mark.parametrize( diff --git a/tests/test_sort.py b/tests/test_sort.py index b12889a674fa..afaf4e152e7a 100644 --- a/tests/test_sort.py +++ b/tests/test_sort.py @@ -4,6 +4,7 @@ from numpy.testing import assert_array_equal, assert_equal, assert_raises import dpnp +from tests.third_party.cupy import testing from .helper import ( assert_dtype_allclose, @@ -61,14 +62,26 @@ def test_argsort_ndarray(self, dtype, axis): expected = np_array.argsort(axis=axis) assert_dtype_allclose(result, expected) - def test_argsort_stable(self): + @pytest.mark.parametrize("kind", [None, "stable"]) + def test_sort_kind(self, kind): np_array = numpy.repeat(numpy.arange(10), 10) dp_array = dpnp.array(np_array) - result = dpnp.argsort(dp_array, kind="stable") + result = dpnp.argsort(dp_array, kind=kind) expected = numpy.argsort(np_array, kind="stable") assert_dtype_allclose(result, expected) + # `stable` keyword is supported in numpy 2.0 and above + @testing.with_requires("numpy>=2.0") + @pytest.mark.parametrize("stable", [None, False, True]) + def test_sort_stable(self, stable): + np_array = numpy.repeat(numpy.arange(10), 10) + dp_array = dpnp.array(np_array) + + result = dpnp.argsort(dp_array, stable="stable") + expected = numpy.argsort(np_array, stable=True) + assert_dtype_allclose(result, expected) + def test_argsort_zero_dim(self): np_array = numpy.array(2.5) dp_array = dpnp.array(np_array) @@ -82,15 +95,6 @@ def test_argsort_zero_dim(self): expected = numpy.argsort(np_array, axis=None) assert_dtype_allclose(result, expected) - def test_sort_notimplemented(self): - dp_array = dpnp.arange(10) - - with pytest.raises(NotImplementedError): - dpnp.argsort(dp_array, kind="quicksort") - - with pytest.raises(NotImplementedError): - dpnp.argsort(dp_array, order=["age"]) - class TestSearchSorted: @pytest.mark.parametrize("side", ["left", "right"]) @@ -304,14 +308,26 @@ def test_sort_ndarray(self, dtype, axis): np_array.sort(axis=axis) assert_dtype_allclose(dp_array, np_array) - def test_sort_stable(self): + @pytest.mark.parametrize("kind", [None, "stable"]) + def test_sort_kind(self, kind): np_array = numpy.repeat(numpy.arange(10), 10) dp_array = dpnp.array(np_array) - result = dpnp.sort(dp_array, kind="stable") + result = dpnp.sort(dp_array, kind=kind) expected = numpy.sort(np_array, kind="stable") assert_dtype_allclose(result, expected) + # `stable` keyword is supported in numpy 2.0 and above + @testing.with_requires("numpy>=2.0") + @pytest.mark.parametrize("stable", [None, False, True]) + def test_sort_stable(self, stable): + np_array = numpy.repeat(numpy.arange(10), 10) + dp_array = dpnp.array(np_array) + + result = dpnp.sort(dp_array, stable="stable") + expected = numpy.sort(np_array, stable=True) + assert_dtype_allclose(result, expected) + def test_sort_ndarray_axis_none(self): a = numpy.random.uniform(-10, 10, 12) dp_array = dpnp.array(a).reshape(6, 2) diff --git a/tests/test_statistics.py b/tests/test_statistics.py index ad617752d049..3baccad541b3 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -18,119 +18,6 @@ ) -@pytest.mark.parametrize( - "dtype", get_all_dtypes(no_none=True, no_bool=True, no_complex=True) -) -@pytest.mark.parametrize("size", [2, 4, 8, 16, 3, 9, 27, 81]) -def test_median(dtype, size): - a = numpy.arange(size, dtype=dtype) - ia = dpnp.array(a) - - np_res = numpy.median(a) - dpnp_res = dpnp.median(ia) - - assert_allclose(dpnp_res, np_res) - - -class TestMaxMin: - @pytest.mark.parametrize("func", ["max", "min"]) - @pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)]) - @pytest.mark.parametrize("keepdims", [False, True]) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) - def test_func(self, func, axis, keepdims, dtype): - a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8)) - ia = dpnp.array(a) - - np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims) - dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims) - assert_dtype_allclose(dpnp_res, np_res) - - @pytest.mark.parametrize("func", ["max", "min"]) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) - def test_strided(self, func, dtype): - a = numpy.arange(20, dtype=dtype) - ia = dpnp.array(a) - - np_res = getattr(numpy, func)(a[::-1]) - dpnp_res = getattr(dpnp, func)(ia[::-1]) - assert_dtype_allclose(dpnp_res, np_res) - - np_res = getattr(numpy, func)(a[::2]) - dpnp_res = getattr(dpnp, func)(ia[::2]) - assert_dtype_allclose(dpnp_res, np_res) - - @pytest.mark.parametrize("func", ["max", "min"]) - @pytest.mark.parametrize("axis", [None, 0, 1, -1]) - @pytest.mark.parametrize("keepdims", [False, True]) - def test_bool(self, func, axis, keepdims): - a = numpy.arange(2, dtype=numpy.bool_) - a = numpy.tile(a, (2, 2)) - ia = dpnp.array(a) - - np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims) - dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims) - assert_dtype_allclose(dpnp_res, np_res) - - @pytest.mark.parametrize("func", ["max", "min"]) - def test_out(self, func): - a = numpy.arange(12, dtype=numpy.float32).reshape((2, 2, 3)) - ia = dpnp.array(a) - - # out is dpnp_array - np_res = getattr(numpy, func)(a, axis=0) - dpnp_out = dpnp.empty(np_res.shape, dtype=np_res.dtype) - dpnp_res = getattr(dpnp, func)(ia, axis=0, out=dpnp_out) - assert dpnp_out is dpnp_res - assert_allclose(dpnp_res, np_res) - - # out is usm_ndarray - dpt_out = dpt.empty(np_res.shape, dtype=np_res.dtype) - dpnp_res = getattr(dpnp, func)(ia, axis=0, out=dpt_out) - assert dpt_out is dpnp_res.get_array() - assert_allclose(dpnp_res, np_res) - - # output is numpy array -> Error - dpnp_res = numpy.empty_like(np_res) - with pytest.raises(TypeError): - getattr(dpnp, func)(ia, axis=0, out=dpnp_res) - - # output has incorrect shape -> Error - dpnp_res = dpnp.array(numpy.zeros((4, 2))) - with pytest.raises(ValueError): - getattr(dpnp, func)(ia, axis=0, out=dpnp_res) - - @pytest.mark.usefixtures("suppress_complex_warning") - @pytest.mark.parametrize("func", ["max", "min"]) - @pytest.mark.parametrize("arr_dt", get_all_dtypes(no_none=True)) - @pytest.mark.parametrize("out_dt", get_all_dtypes(no_none=True)) - def test_out_dtype(self, func, arr_dt, out_dt): - a = ( - numpy.arange(12, dtype=numpy.float32) - .reshape((2, 2, 3)) - .astype(dtype=arr_dt) - ) - out = numpy.zeros_like(a, shape=(2, 3), dtype=out_dt) - - ia = dpnp.array(a) - iout = dpnp.array(out) - - result = getattr(dpnp, func)(ia, out=iout, axis=1) - expected = getattr(numpy, func)(a, out=out, axis=1) - assert_array_equal(expected, result) - assert result is iout - - @pytest.mark.parametrize("func", ["max", "min"]) - def test_error(self, func): - ia = dpnp.arange(5) - # where is not supported - with pytest.raises(NotImplementedError): - getattr(dpnp, func)(ia, where=False) - - # initial is not supported - with pytest.raises(NotImplementedError): - getattr(dpnp, func)(ia, initial=6) - - class TestAverage: @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize("axis", [None, 0, 1]) @@ -252,6 +139,91 @@ def test_avg_error(self): dpnp.average(a, axis=0, weights=w) +class TestMaxMin: + @pytest.mark.parametrize("func", ["max", "min"]) + @pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)]) + @pytest.mark.parametrize("keepdims", [False, True]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_func(self, func, axis, keepdims, dtype): + a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8)) + ia = dpnp.array(a) + + np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims) + dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims) + assert_dtype_allclose(dpnp_res, np_res) + + @pytest.mark.parametrize("func", ["max", "min"]) + @pytest.mark.parametrize("axis", [None, 0, 1, -1]) + @pytest.mark.parametrize("keepdims", [False, True]) + def test_bool(self, func, axis, keepdims): + a = numpy.arange(2, dtype=numpy.bool_) + a = numpy.tile(a, (2, 2)) + ia = dpnp.array(a) + + np_res = getattr(numpy, func)(a, axis=axis, keepdims=keepdims) + dpnp_res = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims) + assert_dtype_allclose(dpnp_res, np_res) + + @pytest.mark.parametrize("func", ["max", "min"]) + def test_out(self, func): + a = numpy.arange(12, dtype=numpy.float32).reshape((2, 2, 3)) + ia = dpnp.array(a) + + # out is dpnp_array + np_res = getattr(numpy, func)(a, axis=0) + dpnp_out = dpnp.empty(np_res.shape, dtype=np_res.dtype) + dpnp_res = getattr(dpnp, func)(ia, axis=0, out=dpnp_out) + assert dpnp_out is dpnp_res + assert_allclose(dpnp_res, np_res) + + # out is usm_ndarray + dpt_out = dpt.empty(np_res.shape, dtype=np_res.dtype) + dpnp_res = getattr(dpnp, func)(ia, axis=0, out=dpt_out) + assert dpt_out is dpnp_res.get_array() + assert_allclose(dpnp_res, np_res) + + # output is numpy array -> Error + dpnp_res = numpy.empty_like(np_res) + with pytest.raises(TypeError): + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) + + # output has incorrect shape -> Error + dpnp_res = dpnp.array(numpy.zeros((4, 2))) + with pytest.raises(ValueError): + getattr(dpnp, func)(ia, axis=0, out=dpnp_res) + + @pytest.mark.usefixtures("suppress_complex_warning") + @pytest.mark.parametrize("func", ["max", "min"]) + @pytest.mark.parametrize("arr_dt", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize("out_dt", get_all_dtypes(no_none=True)) + def test_out_dtype(self, func, arr_dt, out_dt): + a = ( + numpy.arange(12, dtype=numpy.float32) + .reshape((2, 2, 3)) + .astype(dtype=arr_dt) + ) + out = numpy.zeros_like(a, shape=(2, 3), dtype=out_dt) + + ia = dpnp.array(a) + iout = dpnp.array(out) + + result = getattr(dpnp, func)(ia, out=iout, axis=1) + expected = getattr(numpy, func)(a, out=out, axis=1) + assert_array_equal(expected, result) + assert result is iout + + @pytest.mark.parametrize("func", ["max", "min"]) + def test_error(self, func): + ia = dpnp.arange(5) + # where is not supported + with pytest.raises(NotImplementedError): + getattr(dpnp, func)(ia, where=False) + + # initial is not supported + with pytest.raises(NotImplementedError): + getattr(dpnp, func)(ia, initial=6) + + class TestMean: @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)]) @@ -265,15 +237,18 @@ def test_mean(self, dtype, axis, keepdims): assert_dtype_allclose(result, expected) @pytest.mark.parametrize("dtype", get_all_dtypes()) - @pytest.mark.parametrize("axis", [0, 1]) - def test_mean_out(self, dtype, axis): - dp_array = dpnp.array([[0, 1, 2], [3, 4, 0]], dtype=dtype) - np_array = dpnp.asnumpy(dp_array) - - expected = numpy.mean(np_array, axis=axis) - out = dpnp.empty_like(dpnp.asarray(expected)) - result = dpnp.mean(dp_array, axis=axis, out=out) - assert result is out + @pytest.mark.parametrize( + "axis, out_shape", [(0, (3,)), (1, (2,)), ((0, 1), ())] + ) + def test_mean_out(self, dtype, axis, out_shape): + ia = dpnp.array([[5, 1, 2], [8, 4, 3]], dtype=dtype) + a = dpnp.asnumpy(ia) + + out_np = numpy.empty_like(a, shape=out_shape) + out_dp = dpnp.empty_like(ia, shape=out_shape) + expected = numpy.mean(a, axis=axis, out=out_np) + result = dpnp.mean(ia, axis=axis, out=out_dp) + assert result is out_dp assert_dtype_allclose(result, expected) @pytest.mark.parametrize("dtype", get_complex_dtypes()) @@ -310,19 +285,6 @@ def test_mean_empty(self, axis, shape): expected = numpy.mean(np_array, axis=axis) assert_allclose(expected, result) - @pytest.mark.parametrize("dtype", get_all_dtypes()) - def test_mean_strided(self, dtype): - dp_array = dpnp.array([-2, -1, 0, 1, 0, 2], dtype=dtype) - np_array = dpnp.asnumpy(dp_array) - - result = dpnp.mean(dp_array[::-1]) - expected = numpy.mean(np_array[::-1]) - assert_allclose(expected, result) - - result = dpnp.mean(dp_array[::2]) - expected = numpy.mean(np_array[::2]) - assert_allclose(expected, result) - def test_mean_scalar(self): dp_array = dpnp.array(5) np_array = dpnp.asnumpy(dp_array) @@ -337,6 +299,113 @@ def test_mean_NotImplemented(self): dpnp.mean(ia, where=False) +class TestMedian: + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize("size", [1, 2, 3, 4, 8, 9]) + def test_basic(self, dtype, size): + if dtype == dpnp.bool: + a = numpy.arange(2, dtype=dtype) + a = numpy.repeat(a, size) + else: + a = numpy.array(numpy.random.uniform(-5, 5, size), dtype=dtype) + ia = dpnp.array(a) + + expected = numpy.median(a) + result = dpnp.median(ia) + + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("axis", [None, 0, (-1,), [0, 1], (0, -2, -1)]) + @pytest.mark.parametrize("keepdims", [True, False]) + def test_axis(self, axis, keepdims): + a = numpy.random.uniform(-5, 5, 24).reshape(2, 3, 4) + ia = dpnp.array(a) + + expected = numpy.median(a, axis=axis, keepdims=keepdims) + result = dpnp.median(ia, axis=axis, keepdims=keepdims) + + assert_dtype_allclose(result, expected) + + @pytest.mark.usefixtures( + "suppress_invalid_numpy_warnings", + "suppress_mean_empty_slice_numpy_warnings", + ) + @pytest.mark.parametrize("axis", [0, 1, (0, 1)]) + @pytest.mark.parametrize("shape", [(2, 3), (2, 0), (0, 3)]) + def test_empty(self, axis, shape): + a = numpy.empty(shape) + ia = dpnp.array(a) + + result = dpnp.median(ia, axis=axis) + expected = numpy.median(a, axis=axis) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_all_dtypes()) + @pytest.mark.parametrize( + "axis, out_shape", [(0, (3,)), (1, (2,)), ((0, 1), ())] + ) + def test_out(self, dtype, axis, out_shape): + a = numpy.array([[5, 1, 2], [8, 4, 3]], dtype=dtype) + ia = dpnp.array(a) + + out_np = numpy.empty_like(a, shape=out_shape) + out_dp = dpnp.empty_like(ia, shape=out_shape) + expected = numpy.median(a, axis=axis, out=out_np) + result = dpnp.median(ia, axis=axis, out=out_dp) + assert result is out_dp + assert_dtype_allclose(result, expected) + + def test_0d_array(self): + a = numpy.array(20) + ia = dpnp.array(a) + + result = dpnp.median(ia) + expected = numpy.median(a) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("axis", [None, 0, (0, 1), (0, -2, -1)]) + @pytest.mark.parametrize("keepdims", [True, False]) + def test_nan(self, axis, keepdims): + a = numpy.random.uniform(-5, 5, 24).reshape(2, 3, 4) + a[0, 0, 0] = a[-1, -1, -1] = numpy.nan + ia = dpnp.array(a) + + expected = numpy.median(a, axis=axis, keepdims=keepdims) + result = dpnp.median(ia, axis=axis, keepdims=keepdims) + + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("axis", [None, 0, -1, (0, -2, -1)]) + @pytest.mark.parametrize("keepdims", [True, False]) + def test_overwrite_input(self, axis, keepdims): + a = numpy.random.uniform(-5, 5, 24).reshape(2, 3, 4) + ia = dpnp.array(a) + + b = a.copy() + ib = ia.copy() + expected = numpy.median( + b, axis=axis, keepdims=keepdims, overwrite_input=True + ) + result = dpnp.median( + ib, axis=axis, keepdims=keepdims, overwrite_input=True + ) + assert not numpy.all(a == b) + assert not dpnp.all(ia == ib) + + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("axis", [None, 0, (-1,), [0, 1]]) + @pytest.mark.parametrize("overwrite_input", [True, False]) + def test_usm_ndarray(self, axis, overwrite_input): + a = numpy.random.uniform(-5, 5, 24).reshape(2, 3, 4) + ia = dpt.asarray(a) + + expected = numpy.median(a, axis=axis, overwrite_input=overwrite_input) + result = dpnp.median(ia, axis=axis, overwrite_input=overwrite_input) + + assert_dtype_allclose(result, expected) + + class TestVar: @pytest.mark.usefixtures( "suppress_divide_invalid_numpy_warnings", "suppress_dof_numpy_warnings" @@ -390,19 +459,6 @@ def test_var_empty(self, axis, shape): expected = numpy.var(np_array, axis=axis) assert_dtype_allclose(result, expected) - @pytest.mark.parametrize("dtype", get_all_dtypes()) - def test_var_strided(self, dtype): - dp_array = dpnp.array([-2, -1, 0, 1, 0, 2], dtype=dtype) - np_array = dpnp.asnumpy(dp_array) - - result = dpnp.var(dp_array[::-1]) - expected = numpy.var(np_array[::-1]) - assert_dtype_allclose(result, expected) - - result = dpnp.var(dp_array[::2]) - expected = numpy.var(np_array[::2]) - assert_dtype_allclose(result, expected) - @pytest.mark.usefixtures("suppress_complex_warning") @pytest.mark.parametrize("dt_in", get_all_dtypes(no_bool=True)) @pytest.mark.parametrize("dt_out", get_float_complex_dtypes()) @@ -486,19 +542,6 @@ def test_std_empty(self, axis, shape): expected = numpy.std(np_array, axis=axis) assert_dtype_allclose(result, expected) - @pytest.mark.parametrize("dtype", get_all_dtypes()) - def test_std_strided(self, dtype): - dp_array = dpnp.array([-2, -1, 0, 1, 0, 2], dtype=dtype) - np_array = dpnp.asnumpy(dp_array) - - result = dpnp.std(dp_array[::-1]) - expected = numpy.std(np_array[::-1]) - assert_dtype_allclose(result, expected) - - result = dpnp.std(dp_array[::2]) - expected = numpy.std(np_array[::2]) - assert_dtype_allclose(result, expected) - @pytest.mark.usefixtures("suppress_complex_warning") @pytest.mark.parametrize("dt_in", get_all_dtypes(no_bool=True)) @pytest.mark.parametrize("dt_out", get_float_complex_dtypes()) diff --git a/tests/test_strides.py b/tests/test_strides.py index 8db19e31f57f..0f6732ab9a44 100644 --- a/tests/test_strides.py +++ b/tests/test_strides.py @@ -56,46 +56,78 @@ def test_strides(func_name, dtype): "arctan", "arctanh", "argsort", - "cbrt", - "ceil", + "conjugate", "copy", "cos", "cosh", "conjugate", - "degrees", "ediff1d", "exp", "exp2", "expm1", - "fabs", - "floor", + "imag", "log", "log10", "log1p", "log2", + "max", + "min", + "mean", + "median", "negative", "positive", - "radians", + "real", "sign", "sin", "sinh", "sort", "sqrt", "square", + "std", "tan", "tanh", + "var", + ], +) +@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True, no_bool=True)) +@pytest.mark.parametrize("stride", [2, -1, -3]) +def test_strides_1arg_support_complex(func_name, dtype, stride): + a = numpy.arange(10, dtype=dtype) + dpa = dpnp.array(a) + b = a[::stride] + dpb = dpa[::stride] + + dpnp_func = _getattr(dpnp, func_name) + result = dpnp_func(dpb) + + numpy_func = _getattr(numpy, func_name) + expected = numpy_func(b) + + assert_dtype_allclose(result, expected) + + +@pytest.mark.parametrize( + "func_name", + [ + "cbrt", + "ceil", + "degrees", + "fabs", + "floor", + "radians", "trunc", "unwrap", ], ) -@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) -@pytest.mark.parametrize("shape", [(10,)], ids=["(10,)"]) -def test_strides_1arg(func_name, dtype, shape): - a = numpy.arange(numpy.prod(shape), dtype=dtype).reshape(shape) - b = a[::2] - - dpa = dpnp.reshape(dpnp.arange(numpy.prod(shape), dtype=dtype), shape) - dpb = dpa[::2] +@pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_bool=True, no_complex=True) +) +@pytest.mark.parametrize("stride", [2, -1, -3]) +def test_strides_1arg(func_name, dtype, stride): + a = numpy.arange(10, dtype=dtype) + dpa = dpnp.array(a) + b = a[::stride] + dpb = dpa[::stride] dpnp_func = _getattr(dpnp, func_name) result = dpnp_func(dpb) @@ -103,7 +135,13 @@ def test_strides_1arg(func_name, dtype, shape): numpy_func = _getattr(numpy, func_name) expected = numpy_func(b) - assert_allclose(result, expected, rtol=1e-06) + # numpy.ceil, numpy.floor, numpy.trunc always return float dtype for NumPy < 2.0.0 + # while for NumPy >= 2.0.0, output has the dtype of input (dpnp follows this behavior) + if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0": + check_type = False + else: + check_type = True + assert_dtype_allclose(result, expected, check_type=check_type) @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) @@ -146,32 +184,6 @@ def test_reduce_hypot(dtype): assert_allclose(result, expected) -@pytest.mark.parametrize( - "func_name", - [ - "conjugate", - "imag", - "real", - ], -) -@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) -@pytest.mark.parametrize("shape", [(10,)], ids=["(10,)"]) -def test_strides_1arg_complex(func_name, dtype, shape): - a = numpy.arange(numpy.prod(shape), dtype=dtype).reshape(shape) - b = a[::2] - - dpa = dpnp.reshape(dpnp.arange(numpy.prod(shape), dtype=dtype), shape) - dpb = dpa[::2] - - dpnp_func = _getattr(dpnp, func_name) - result = dpnp_func(dpb) - - numpy_func = _getattr(numpy, func_name) - expected = numpy_func(b) - - assert_allclose(result, expected, rtol=1e-06) - - @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) @pytest.mark.parametrize("shape", [(10,)], ids=["(10,)"]) def test_strides_erf(dtype, shape): diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 5179f97872d9..74c65a65a59c 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -473,6 +473,7 @@ def test_meshgrid(device): pytest.param("log2", [1.0, 2.0, 4.0, 7.0]), pytest.param("max", [1.0, 2.0, 4.0, 7.0]), pytest.param("mean", [1.0, 2.0, 4.0, 7.0]), + pytest.param("median", [1.0, 2.0, 4.0, 7.0]), pytest.param("min", [1.0, 2.0, 4.0, 7.0]), pytest.param("nanargmax", [1.0, 2.0, 4.0, dpnp.nan]), pytest.param("nanargmin", [1.0, 2.0, 4.0, dpnp.nan]), diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 8db3c5caab25..2c1fb32d2301 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -598,6 +598,7 @@ def test_norm(usm_type, ord, axis): pytest.param("logsumexp", [1.0, 2.0, 4.0, 7.0]), pytest.param("max", [1.0, 2.0, 4.0, 7.0]), pytest.param("mean", [1.0, 2.0, 4.0, 7.0]), + pytest.param("median", [1.0, 2.0, 4.0, 7.0]), pytest.param("min", [1.0, 2.0, 4.0, 7.0]), pytest.param("nanargmax", [1.0, 2.0, 4.0, dp.nan]), pytest.param("nanargmin", [1.0, 2.0, 4.0, dp.nan]), diff --git a/tests/third_party/cupy/statistics_tests/test_meanvar.py b/tests/third_party/cupy/statistics_tests/test_meanvar.py index ce5de823d61d..465d0e8e5113 100644 --- a/tests/third_party/cupy/statistics_tests/test_meanvar.py +++ b/tests/third_party/cupy/statistics_tests/test_meanvar.py @@ -18,42 +18,36 @@ def test_median_noaxis(self, xp, dtype): a = testing.shaped_random((3, 4, 5), xp, dtype) return xp.median(a) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) def test_median_axis1(self, xp, dtype): a = testing.shaped_random((3, 4, 5), xp, dtype) return xp.median(a, axis=1) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) def test_median_axis2(self, xp, dtype): a = testing.shaped_random((3, 4, 5), xp, dtype) return xp.median(a, axis=2) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() - @testing.numpy_cupy_allclose() + @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) def test_median_overwrite_input(self, xp, dtype): a = testing.shaped_random((3, 4, 5), xp, dtype) return xp.median(a, overwrite_input=True) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) def test_median_keepdims_axis1(self, xp, dtype): a = testing.shaped_random((3, 4, 5), xp, dtype) return xp.median(a, axis=1, keepdims=True) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.for_all_dtypes() @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) def test_median_keepdims_noaxis(self, xp, dtype): a = testing.shaped_random((3, 4, 5), xp, dtype) return xp.median(a, keepdims=True) - @pytest.mark.usefixtures("allow_fall_back_on_numpy") def test_median_invalid_axis(self): for xp in [numpy, cupy]: a = testing.shaped_random((3, 4, 5), xp) @@ -67,14 +61,16 @@ def test_median_invalid_axis(self): return xp.median(a, (-a.ndim - 1, 1), keepdims=False) with pytest.raises(AxisError): - return xp.median( - a, - ( - 0, - a.ndim, - ), - keepdims=False, - ) + return xp.median(a, (0, a.ndim), keepdims=False) + + @testing.for_dtypes("efdFD") + @testing.numpy_cupy_allclose() + def test_median_nan(self, xp, dtype): + a = xp.array( + [[xp.nan, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, xp.nan]], + dtype=dtype, + ) + return xp.median(a, axis=1) @testing.parameterize( @@ -86,7 +82,6 @@ def test_median_invalid_axis(self): } ) ) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestMedianAxis: @testing.for_all_dtypes() @testing.numpy_cupy_allclose(type_check=has_support_aspect64())