From 5b1957d2adf883136f6e9383581ed50b91c8332f Mon Sep 17 00:00:00 2001 From: Alexander Kalistratov Date: Mon, 16 Sep 2024 18:08:13 +0200 Subject: [PATCH] Implementation of bincounts --- .../extensions/statistics/CMakeLists.txt | 1 + .../extensions/statistics/bincount.cpp | 231 ++++++++++++++++ .../extensions/statistics/bincount.hpp | 66 +++++ dpnp/backend/extensions/statistics/common.hpp | 13 + .../extensions/statistics/histogram.cpp | 11 +- .../statistics/histogram_common.cpp | 57 ++-- .../statistics/histogram_common.hpp | 11 +- .../extensions/statistics/statistics_py.cpp | 2 + dpnp/dpnp_iface_histograms.py | 253 +++++++++++++++--- dpnp/dpnp_iface_statistics.py | 23 -- tests/test_histogram.py | 111 ++++++-- tests/test_statistics.py | 35 --- tests/test_sycl_queue.py | 21 ++ tests/test_usm_type.py | 12 + .../cupy/statistics_tests/test_histogram.py | 16 +- tests_external/skipped_tests_numpy.tbl | 1 - 16 files changed, 706 insertions(+), 158 deletions(-) create mode 100644 dpnp/backend/extensions/statistics/bincount.cpp create mode 100644 dpnp/backend/extensions/statistics/bincount.hpp diff --git a/dpnp/backend/extensions/statistics/CMakeLists.txt b/dpnp/backend/extensions/statistics/CMakeLists.txt index 20c868066576..2b784555630d 100644 --- a/dpnp/backend/extensions/statistics/CMakeLists.txt +++ b/dpnp/backend/extensions/statistics/CMakeLists.txt @@ -27,6 +27,7 @@ set(python_module_name _statistics_impl) set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp ${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp ${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp diff --git a/dpnp/backend/extensions/statistics/bincount.cpp b/dpnp/backend/extensions/statistics/bincount.cpp new file mode 100644 index 000000000000..a7a67c12ea47 --- /dev/null +++ b/dpnp/backend/extensions/statistics/bincount.cpp @@ -0,0 +1,231 @@ +//***************************************************************************** +// 
Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#include + +#include +#include + +#include "bincount.hpp" +#include "histogram_common.hpp" + +using dpctl::tensor::usm_ndarray; + +using namespace statistics::histogram; +using namespace statistics::common; + +namespace +{ + +template +struct BincountEdges +{ + static constexpr bool const sync_after_init = false; + using boundsT = std::tuple; + + BincountEdges(const T &min, const T &max) + { + this->min = min; + this->max = max; + } + + template + void init(const sycl::nd_item<_Dims> &) const + { + } + + boundsT get_bounds() const + { + return {min, max}; + } + + template + size_t get_bin(const sycl::nd_item<_Dims> &, + const dT *val, + const boundsT &) const + { + return val[0] - min; + } + + template + bool in_bounds(const dT *val, const boundsT &bounds) const + { + return check_in_bounds(val[0], std::get<0>(bounds), + std::get<1>(bounds)); + } + +private: + T min; + T max; +}; + +template +struct BincountF +{ + static sycl::event impl(sycl::queue &exec_q, + const void *vin, + const int64_t min, + const int64_t max, + const void *vweights, + void *vout, + const size_t, + const size_t size, + const std::vector &depends) + { + const T *in = static_cast(vin); + const HistType *weights = static_cast(vweights); + // shift output pointer by min elements + HistType *out = static_cast(vout) + min; + + const size_t needed_bins_count = (max - min) + 1; + + const uint32_t local_size = get_max_local_size(exec_q); + + constexpr uint32_t WorkPI = 128; // empirically found number + const auto nd_range = make_ndrange(size, local_size, WorkPI); + + return exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + constexpr uint32_t dims = 1; + + auto dispatch_bins = [&](const auto &weights) { + const auto local_mem_size = + get_local_mem_size_in_items(exec_q); + if (local_mem_size >= needed_bins_count) { + const uint32_t local_hist_count = + get_local_hist_copies_count(local_mem_size, 
local_size, + needed_bins_count); + + auto hist = HistWithLocalCopies( + out, needed_bins_count, local_hist_count, cgh); + + auto edges = BincountEdges(min, max); + submit_histogram(in, size, dims, WorkPI, hist, edges, + weights, nd_range, cgh); + } + else { + auto hist = HistGlobalMemory(out); + auto edges = BincountEdges(min, max); + submit_histogram(in, size, dims, WorkPI, hist, edges, + weights, nd_range, cgh); + } + }; + + if (weights) { + auto _weights = Weights(weights); + dispatch_bins(_weights); + } + else { + auto _weights = NoWeights(); + dispatch_bins(_weights); + } + }); + } +}; + +using SupportedTypes = std::tuple, + std::tuple, + std::tuple>; + +} // namespace + +Bincount::Bincount() : dispatch_table("sample", "histogram") +{ + dispatch_table.populate_dispatch_table(); +} + +std::tuple Bincount::call( + const dpctl::tensor::usm_ndarray &sample, + const int64_t min, + const int64_t max, + const std::optional &weights, + dpctl::tensor::usm_ndarray &histogram, + const std::vector &depends) +{ + validate(sample, std::optional(), weights, + histogram); + + if (sample.get_size() == 0) { + return {sycl::event(), sycl::event()}; + } + + const int sample_typenum = sample.get_typenum(); + const int hist_typenum = histogram.get_typenum(); + + auto bincount_func = dispatch_table.get(sample_typenum, hist_typenum); + + auto exec_q = sample.get_queue(); + + void *weights_ptr = + weights.has_value() ? 
weights.value().get_data() : nullptr; + + auto ev = bincount_func(exec_q, sample.get_data(), min, max, weights_ptr, + histogram.get_data(), histogram.get_shape(0), + sample.get_shape(0), depends); + + sycl::event args_ev; + if (weights.has_value()) { + args_ev = dpctl::utils::keep_args_alive( + exec_q, {sample, weights.value(), histogram}, {ev}); + } + else { + args_ev = + dpctl::utils::keep_args_alive(exec_q, {sample, histogram}, {ev}); + } + + return {args_ev, ev}; +} + +std::unique_ptr bincount; + +void statistics::histogram::populate_bincount(py::module_ m) +{ + using namespace std::placeholders; + + bincount.reset(new Bincount()); + + auto bincount_func = + [bincountp = bincount.get()]( + const dpctl::tensor::usm_ndarray &sample, int64_t min, int64_t max, + std::optional &weights, + dpctl::tensor::usm_ndarray &histogram, + const std::vector &depends) { + return bincountp->call(sample, min, max, weights, histogram, + depends); + }; + + m.def("bincount", bincount_func, + "Count number of occurrences of each value in array of non-negative " + "ints.", + py::arg("sample"), py::arg("min"), py::arg("max"), py::arg("weights"), + py::arg("histogram"), py::arg("depends") = py::list()); + + auto bincount_dtypes = [bincountp = bincount.get()]() { + return bincountp->dispatch_table.get_all_supported_types(); + }; + + m.def("bincount_dtypes", bincount_dtypes, + "Get the supported data types for bincount."); +} diff --git a/dpnp/backend/extensions/statistics/bincount.hpp b/dpnp/backend/extensions/statistics/bincount.hpp new file mode 100644 index 000000000000..70a17431383f --- /dev/null +++ b/dpnp/backend/extensions/statistics/bincount.hpp @@ -0,0 +1,66 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include +#include + +#include "dispatch_table.hpp" + +namespace dpctl_td_ns = dpctl::tensor::type_dispatch; + +namespace statistics +{ +namespace histogram +{ +struct Bincount +{ + using FnT = sycl::event (*)(sycl::queue &, + const void *, + const int64_t, + const int64_t, + const void *, + void *, + const size_t, + const size_t, + const std::vector &); + + common::DispatchTable2 dispatch_table; + + Bincount(); + + std::tuple + call(const dpctl::tensor::usm_ndarray &input, + const int64_t min, + const int64_t max, + const std::optional &weights, + dpctl::tensor::usm_ndarray &output, + const std::vector &depends); +}; + +void populate_bincount(py::module_ m); +} // namespace histogram +} // namespace statistics diff --git a/dpnp/backend/extensions/statistics/common.hpp b/dpnp/backend/extensions/statistics/common.hpp index 518a3cf86a9e..99a31002ea13 100644 --- a/dpnp/backend/extensions/statistics/common.hpp +++ b/dpnp/backend/extensions/statistics/common.hpp @@ -165,6 +165,19 @@ size_t get_local_mem_size_in_items(const sycl::device &device, size_t reserve) return get_local_mem_size_in_bytes(device, sizeof(T) * reserve) / sizeof(T); } +template +inline size_t get_local_mem_size_in_items(const sycl::queue &queue) +{ + return get_local_mem_size_in_items(queue.get_device()); +} + +template +inline size_t get_local_mem_size_in_items(const sycl::queue &queue, + size_t reserve) +{ + return get_local_mem_size_in_items(queue.get_device(), reserve); +} + template sycl::nd_range make_ndrange(const sycl::range &global_range, const sycl::range &local_range, diff --git a/dpnp/backend/extensions/statistics/histogram.cpp b/dpnp/backend/extensions/statistics/histogram.cpp index 4317280b5a8d..848f89451236 100644 --- a/dpnp/backend/extensions/statistics/histogram.cpp +++ b/dpnp/backend/extensions/statistics/histogram.cpp @@ -94,9 +94,8 @@ struct HistogramEdges template bool 
in_bounds(const dT *val, const boundsT &bounds) const { - Less
<dT> _less; - return !_less(val[0], std::get<0>(bounds)) && - !_less(std::get<1>(bounds), val[0]) && !IsNan<dT>
::isnan(val[0]); + return check_in_bounds(val[0], std::get<0>(bounds), + std::get<1>(bounds)); } private: @@ -110,7 +109,7 @@ template using UncachedEdges = HistogramEdges>; template -struct histogram_kernel +struct HistogramF { static sycl::event impl(sycl::queue &exec_q, const void *vin, @@ -185,7 +184,7 @@ struct histogram_kernel }; template -using histogram_kernel_ = histogram_kernel; +using HistogramF_ = HistogramF; } // namespace @@ -212,7 +211,7 @@ using SupportedTypes = std::tuple, Histogram::Histogram() : dispatch_table("sample", "histogram") { - dispatch_table.populate_dispatch_table(); + dispatch_table.populate_dispatch_table(); } std::tuple diff --git a/dpnp/backend/extensions/statistics/histogram_common.cpp b/dpnp/backend/extensions/statistics/histogram_common.cpp index 30197f74e422..e2445b78bb3f 100644 --- a/dpnp/backend/extensions/statistics/histogram_common.cpp +++ b/dpnp/backend/extensions/statistics/histogram_common.cpp @@ -50,16 +50,24 @@ namespace histogram { void validate(const usm_ndarray &sample, - const usm_ndarray &bins, - std::optional &weights, + const std::optional &bins, + const std::optional &weights, const usm_ndarray &histogram) { auto exec_q = sample.get_queue(); using array_ptr = const usm_ndarray *; - std::vector arrays{&sample, &bins, &histogram}; + std::vector arrays{&sample, &histogram}; std::unordered_map names = { - {arrays[0], "sample"}, {arrays[1], "bins"}, {arrays[2], "histogram"}}; + {arrays[0], "sample"}, {arrays[1], "histogram"}}; + + array_ptr bins_ptr = nullptr; + + if (bins.has_value()) { + bins_ptr = &bins.value(); + arrays.push_back(bins_ptr); + names.insert({bins_ptr, "bins"}); + } array_ptr weights_ptr = nullptr; @@ -116,11 +124,11 @@ void validate(const usm_ndarray &sample, }; check_overlaping(&sample, &histogram); - check_overlaping(&bins, &histogram); + check_overlaping(bins_ptr, &histogram); check_overlaping(weights_ptr, &histogram); - if (bins.get_size() < 2) { - throw py::value_error(get_name(&bins) + + if 
(bins_ptr && bins_ptr->get_size() < 2) { + throw py::value_error(get_name(bins_ptr) + " parameter must have at least 2 elements"); } @@ -160,40 +168,43 @@ void validate(const usm_ndarray &sample, } if (sample.get_ndim() == 1) { - if (bins.get_ndim() != 1) { + if (bins_ptr != nullptr && bins_ptr->get_ndim() != 1) { throw py::value_error(get_name(&sample) + " parameter is 1d, but " + - get_name(&bins) + " is " + - std::to_string(bins.get_ndim()) + "d"); + get_name(bins_ptr) + " is " + + std::to_string(bins_ptr->get_ndim()) + "d"); } } else if (sample.get_ndim() == 2) { auto sample_count = sample.get_shape(0); auto expected_dims = sample.get_shape(1); - if (bins.get_ndim() != expected_dims) { + if (bins_ptr != nullptr && bins_ptr->get_ndim() != expected_dims) { throw py::value_error(get_name(&sample) + " parameter has shape {" + std::to_string(sample_count) + "x" + std::to_string(expected_dims) + "}" + - ", so " + get_name(&bins) + + ", so " + get_name(bins_ptr) + " parameter expected to be " + std::to_string(expected_dims) + "d. " "Actual " + - std::to_string(bins.get_ndim()) + "d"); + std::to_string(bins->get_ndim()) + "d"); } } - py::ssize_t expected_hist_size = 1; - for (int i = 0; i < bins.get_ndim(); ++i) { - expected_hist_size *= (bins.get_shape(i) - 1); - } + if (bins_ptr != nullptr) { + py::ssize_t expected_hist_size = 1; + for (int i = 0; i < bins_ptr->get_ndim(); ++i) { + expected_hist_size *= (bins_ptr->get_shape(i) - 1); + } - if (histogram.get_size() != expected_hist_size) { - throw py::value_error( - get_name(&histogram) + " and " + get_name(&bins) + - " shape mismatch. " + get_name(&histogram) + - " expected to have size = " + std::to_string(expected_hist_size) + - ". Actual " + std::to_string(histogram.get_size())); + if (histogram.get_size() != expected_hist_size) { + throw py::value_error( + get_name(&histogram) + " and " + get_name(bins_ptr) + + " shape mismatch. 
" + get_name(&histogram) + + " expected to have size = " + + std::to_string(expected_hist_size) + ". Actual " + + std::to_string(histogram.get_size())); + } } int64_t max_hist_size = std::numeric_limits::max() - 1; diff --git a/dpnp/backend/extensions/statistics/histogram_common.hpp b/dpnp/backend/extensions/statistics/histogram_common.hpp index 9d2cbd163024..e7503fe98778 100644 --- a/dpnp/backend/extensions/statistics/histogram_common.hpp +++ b/dpnp/backend/extensions/statistics/histogram_common.hpp @@ -278,6 +278,13 @@ struct Weights T *data = nullptr; }; +template +bool check_in_bounds(const dT &val, const dT &min, const dT &max) +{ + Less
<dT> _less; + return !_less(val, min) && !_less(max, val) && !IsNan<dT>
::isnan(val); +} + template class histogram_kernel; @@ -330,8 +337,8 @@ void submit_histogram(const T *in, } void validate(const usm_ndarray &sample, - const usm_ndarray &bins, - std::optional &weights, + const std::optional &bins, + const std::optional &weights, const usm_ndarray &histogram); uint32_t get_local_hist_copies_count(uint32_t loc_mem_size_in_items, diff --git a/dpnp/backend/extensions/statistics/statistics_py.cpp b/dpnp/backend/extensions/statistics/statistics_py.cpp index e4f533cd7a02..2f3bf6a901c1 100644 --- a/dpnp/backend/extensions/statistics/statistics_py.cpp +++ b/dpnp/backend/extensions/statistics/statistics_py.cpp @@ -30,9 +30,11 @@ #include #include +#include "bincount.hpp" #include "histogram.hpp" PYBIND11_MODULE(_statistics_impl, m) { + statistics::histogram::populate_bincount(m); statistics::histogram::populate_histogram(m); } diff --git a/dpnp/dpnp_iface_histograms.py b/dpnp/dpnp_iface_histograms.py index b5eea9f37bfc..809cb4e5e996 100644 --- a/dpnp/dpnp_iface_histograms.py +++ b/dpnp/dpnp_iface_histograms.py @@ -38,7 +38,6 @@ """ import operator -import warnings import dpctl.utils as dpu import numpy @@ -53,6 +52,7 @@ from .dpnp_utils import map_dtype_to_device __all__ = [ + "bincount", "digitize", "histogram", "histogram_bin_edges", @@ -63,6 +63,35 @@ _range = range +def _result_type_for_device(dtype1, dtype2, device): + rt = dpnp.result_type(dtype1, dtype2) + return map_dtype_to_device(rt, device) + + +def _align_dtypes(a_dtype, bins_dtype, ntype, supported_types, device): + has_fp64 = device.has_aspect_fp64 + has_fp16 = device.has_aspect_fp16 + + a_bin_dtype = _result_type_for_device(a_dtype, bins_dtype, device) + + # histogram implementation doesn't support uint64 as histogram type + # we can use int64 instead. 
Result would be correct even in case of overflow + if ntype == numpy.uint64: + ntype = dpnp.int64 + + if (a_bin_dtype, ntype) in supported_types: + return a_bin_dtype, ntype + + for sample_type, hist_type in supported_types: + if _can_cast( + a_bin_dtype, sample_type, has_fp16, has_fp64 + ) and _can_cast(ntype, hist_type, has_fp16, has_fp64): + return sample_type, hist_type + + # should not happen + return None, None + + def _ravel_check_a_and_weights(a, weights): """ Check input `a` and `weights` arrays, and ravel both. @@ -74,16 +103,6 @@ def _ravel_check_a_and_weights(a, weights): dpnp.check_supported_arrays_type(a) usm_type = a.usm_type - # ensure that the array is a "subtractable" dtype - if a.dtype == dpnp.bool: - warnings.warn( - f"Converting input from {a.dtype} to {numpy.uint8} " - "for compatibility.", - RuntimeWarning, - stacklevel=3, - ) - a = dpnp.astype(a, numpy.uint8) - if weights is not None: # check that `weights` array has supported type dpnp.check_supported_arrays_type(weights) @@ -208,6 +227,189 @@ def _get_bin_edges(a, bins, range, usm_type): return bin_edges, None +def _bincount_validate(x, weights, minlength): + if x.ndim > 1: + raise ValueError("object too deep for desired array") + if x.ndim < 1: + raise ValueError("object of too small depth for desired array") + if not dpnp.issubdtype(x.dtype, dpnp.integer) and not dpnp.issubdtype( + x.dtype, dpnp.bool + ): + raise TypeError("x must be an integer array") + if weights is not None: + if x.shape != weights.shape: + raise ValueError("The weights and x don't have the same length.") + if not ( + dpnp.issubdtype(weights.dtype, dpnp.integer) + or dpnp.issubdtype(weights.dtype, dpnp.floating) + or dpnp.issubdtype(weights.dtype, dpnp.bool) + ): + raise ValueError( + f"Weights must be integer or float. 
Got {weights.dtype}" + ) + + if minlength is not None: + minlength = int(minlength) + if minlength < 0: + raise ValueError("minlength must be non-negative") + + +def _bincount_run_native( + x_casted, weights_casted, minlength, n_dtype, usm_type +): + queue = x_casted.sycl_queue + + max_v = dpnp.max(x_casted) + min_v = dpnp.min(x_casted) + + if min_v < 0: + raise ValueError("x argument must have no negative arguments") + + size = int(dpnp.max(max_v)) + 1 + if minlength is not None: + size = max(size, minlength) + + # bincount implementation uses atomics, but atomics doesn't work with + # host usm memory + n_usm_type = "device" if usm_type == "host" else usm_type + + # bincount implementation requires output array to be filled with zeros + n_casted = dpnp.zeros( + size, dtype=n_dtype, usm_type=n_usm_type, sycl_queue=queue + ) + + _manager = dpu.SequentialOrderManager[queue] + + x_usm = dpnp.get_usm_ndarray(x_casted) + weights_usm = ( + dpnp.get_usm_ndarray(weights_casted) + if weights_casted is not None + else None + ) + n_usm = dpnp.get_usm_ndarray(n_casted) + + mem_ev, bc_ev = statistics_ext.bincount( + x_usm, + min_v, + max_v, + weights_usm, + n_usm, + depends=_manager.submitted_events, + ) + + _manager.add_event_pair(mem_ev, bc_ev) + + return n_casted + + +def bincount(x, weights=None, minlength=None): + """ + bincount(x, /, weights=None, minlength=None) + + Count number of occurrences of each value in array of non-negative ints. + + For full documentation refer to :obj:`numpy.bincount`. + + Parameters + ---------- + x : {dpnp.ndarray, usm_ndarray} + Input 1-dimensional array with nonnegative integer values. + weights : {None, dpnp.ndarray, usm_ndarray}, optional + Weights, array of the same shape as `x`. + Default: ``None`` + minlength : {None, int}, optional + A minimum number of bins for the output array. + Default: ``None`` + + Returns + ------- + out : dpnp.ndarray of ints + The result of binning the input array. 
+ The length of `out` is equal to ``np.amax(x) + 1``. + + See Also + -------- + :obj:`dpnp.histogram` : Compute the histogram of a data set. + :obj:`dpnp.digitize` : Return the indices of the bins to which each value + :obj:`dpnp.unique` : Find the unique elements of an array. + + Examples + -------- + >>> import dpnp as np + >>> np.bincount(np.arange(5)) + array([1, 1, 1, 1, 1]) + >>> np.bincount(np.array([0, 1, 1, 3, 2, 1, 7])) + array([1, 3, 1, 1, 0, 0, 0, 1]) + + >>> x = np.array([0, 1, 1, 3, 2, 1, 7, 23]) + >>> np.bincount(x).size == np.amax(x) + 1 + array(True) + + The input array needs to be of integer dtype, otherwise a + TypeError is raised: + + >>> np.bincount(np.arange(5, dtype=np.float32)) + Traceback (most recent call last): + ... + TypeError: x must be an integer array + + A possible use of ``bincount`` is to perform sums over + variable-size chunks of an array, using the `weights` keyword. + + >>> w = np.array([0.3, 0.5, 0.2, 0.7, 1., -0.6], dtype=np.float32) # weights + >>> x = np.array([0, 1, 1, 2, 2, 2]) + >>> np.bincount(x, weights=w) + array([0.3, 0.7, 1.1], dtype=float32) + + """ + + _bincount_validate(x, weights, minlength) + + x, weights, usm_type = _ravel_check_a_and_weights(x, weights) + + queue = x.sycl_queue + device = queue.sycl_device + + if weights is None: + ntype = dpnp.dtype(dpnp.intp) + else: + # unlike in case of histogram result type is integer if no weights + # provided and float if weights are provided even if weights are integer + ntype = dpnp.default_float_type(sycl_queue=queue) + + weights_casted = None + + supported_types = statistics_ext.bincount_dtypes() + x_casted_dtype, ntype_casted = _align_dtypes( + x.dtype, x.dtype, ntype, supported_types, device + ) + + if x_casted_dtype is None or ntype_casted is None: + raise ValueError( + f"function '{bincount}' does not support input types " + f"({x.dtype}, {ntype}), " + "and the inputs could not be coerced to any " + "supported types" + ) + + x_casted = dpnp.astype(x, 
dtype=x_casted_dtype, copy=False) + + if weights is not None: + weights_casted = dpnp.astype(weights, dtype=ntype_casted, copy=False) + + n_casted = _bincount_run_native( + x_casted, weights_casted, minlength, ntype_casted, usm_type + ) + + n_usm_type = n_casted.usm_type + if usm_type != n_usm_type: + n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type) + else: + n = dpnp.astype(n_casted, ntype, copy=False) + + return n + + def digitize(x, bins, right=False): """ Return the indices of the bins to which each value in input array belongs. @@ -300,35 +502,6 @@ def digitize(x, bins, right=False): return bins.size - dpnp.searchsorted(bins[::-1], x, side=side) -def _result_type_for_device(dtype1, dtype2, device): - rt = dpnp.result_type(dtype1, dtype2) - return map_dtype_to_device(rt, device) - - -def _align_dtypes(a_dtype, bins_dtype, ntype, supported_types, device): - has_fp64 = device.has_aspect_fp64 - has_fp16 = device.has_aspect_fp16 - - a_bin_dtype = _result_type_for_device(a_dtype, bins_dtype, device) - - # histogram implementation doesn't support uint64 as histogram type - # we can use int64 instead. Result would be correct even in case of overflow - if ntype == numpy.uint64: - ntype = dpnp.int64 - - if (a_bin_dtype, ntype) in supported_types: - return a_bin_dtype, ntype - - for sample_type, hist_type in supported_types: - if _can_cast( - a_bin_dtype, sample_type, has_fp16, has_fp64 - ) and _can_cast(ntype, hist_type, has_fp16, has_fp64): - return sample_type, hist_type - - # should not happen - return None, None - - def histogram(a, bins=10, range=None, density=None, weights=None): """ Compute the histogram of a data set. 
diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index cdb8743b1ec2..83473e11a5bb 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -64,7 +64,6 @@ "amax", "amin", "average", - "bincount", "corrcoef", "correlate", "cov", @@ -339,28 +338,6 @@ def average(a, axis=None, weights=None, returned=False, *, keepdims=False): return avg -def bincount(x1, weights=None, minlength=0): - """ - Count number of occurrences of each value in array of non-negative integers. - - For full documentation refer to :obj:`numpy.bincount`. - - See Also - -------- - :obj:`dpnp.unique` : Find the unique elements of an array. - - Examples - -------- - >>> import dpnp as np - >>> res = np.bincount(np.arange(5)) - >>> print(res) - [1, 1, 1, 1, 1] - - """ - - return call_origin(numpy.bincount, x1, weights=weights, minlength=minlength) - - def corrcoef(x, y=None, rowvar=True, *, dtype=None): """ Return Pearson product-moment correlation coefficients. diff --git a/tests/test_histogram.py b/tests/test_histogram.py index c37e5a4316ff..92abe99526ab 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -16,6 +16,7 @@ assert_dtype_allclose, get_all_dtypes, get_float_dtypes, + get_integer_dtypes, has_support_aspect64, ) @@ -257,22 +258,6 @@ def test_outliers_normalization_weights(self): assert_allclose(result_hist, expected_hist) assert_allclose(result_edges, expected_edges) - @pytest.mark.parametrize("xp", [numpy, dpnp]) - def test_bool_conversion(self, xp): - a = xp.array([1, 1, 0], dtype=numpy.uint8) - int_hist, int_edges = xp.histogram(a) - - with suppress_warnings() as sup: - rec = sup.record(RuntimeWarning, "Converting input from .*") - - v = xp.array([True, True, False]) - hist, edges = xp.histogram(v) - - # A warning should be issued - assert len(rec) == 1 - assert_array_equal(hist, int_hist) - assert_array_equal(edges, int_edges) - @pytest.mark.parametrize("density", [True, False]) def test_weights(self, density): v = 
numpy.random.rand(100) @@ -531,3 +516,97 @@ def test_range(self, range, dtype): expected_edges = numpy.histogram_bin_edges(v, bins=bins, range=range) result_edges = dpnp.histogram_bin_edges(iv, bins=bins, range=range) assert_dtype_allclose(result_edges, expected_edges) + + +class TestBincount: + @pytest.mark.parametrize("dtype", get_integer_dtypes()) + def test_rand_data(self, dtype): + n = 100 + upper_bound = 10 if dtype != dpnp.bool_ else 1 + v = numpy.random.randint(0, upper_bound, size=n, dtype=dtype) + iv = dpnp.array(v) + + expected_hist = numpy.bincount(v) + result_hist = dpnp.bincount(iv) + assert_array_equal(result_hist, expected_hist) + + @pytest.mark.parametrize("dtype", get_integer_dtypes()) + def test_arange_data(self, dtype): + v = numpy.arange(100).astype(dtype) + iv = dpnp.array(v) + + expected_hist = numpy.bincount(v) + result_hist = dpnp.bincount(iv) + assert_array_equal(result_hist, expected_hist) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_negative_values(self, xp): + x = xp.array([-1, 2]) + assert_raises(ValueError, xp.bincount, x) + + def test_no_side_effects(self): + v = dpnp.array([1, 2, 3], dtype=dpnp.int64) + copy_v = v.copy() + + # check that ensures that values passed to ``bincount`` are unchanged + _ = dpnp.bincount(v) + assert (v == copy_v).all() + + def test_weights_another_sycl_queue(self): + v = dpnp.arange(5, sycl_queue=dpctl.SyclQueue()) + w = dpnp.arange(7, 12, sycl_queue=dpctl.SyclQueue()) + with assert_raises(ValueError): + dpnp.bincount(v, weights=w) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_weights_unsupported_dtype(self, xp): + v = dpnp.arange(5) + w = dpnp.arange(5, dtype=dpnp.complex64) + with assert_raises(ValueError): + dpnp.bincount(v, weights=w) + + @pytest.mark.parametrize( + "bins_count", + [10, 10**2, 10**3, 10**4, 10**5, 10**6], + ) + def test_different_bins_amount(self, bins_count): + v = numpy.arange(0, bins_count, dtype=int) + iv = dpnp.array(v) + + expected_hist = 
numpy.bincount(v) + result_hist = dpnp.bincount(iv) + assert_array_equal(result_hist, expected_hist) + + @pytest.mark.parametrize( + "array", + [[1, 2, 3], [1, 2, 2, 1, 2, 4], [2, 2, 2, 2]], + ids=["[1, 2, 3]", "[1, 2, 2, 1, 2, 4]", "[2, 2, 2, 2]"], + ) + @pytest.mark.parametrize( + "minlength", [0, 1, 3, 5], ids=["0", "1", "3", "5"] + ) + def test_bincount_minlength(self, array, minlength): + np_a = numpy.array(array) + dpnp_a = dpnp.array(array) + + expected = numpy.bincount(np_a, minlength=minlength) + result = dpnp.bincount(dpnp_a, minlength=minlength) + assert_allclose(expected, result) + + @pytest.mark.parametrize( + "array", [[1, 2, 2, 1, 2, 4]], ids=["[1, 2, 2, 1, 2, 4]"] + ) + @pytest.mark.parametrize( + "weights", + [None, [0.3, 0.5, 0.2, 0.7, 1.0, -0.6], [2, 2, 2, 2, 2, 2]], + ids=["None", "[0.3, 0.5, 0.2, 0.7, 1., -0.6]", "[2, 2, 2, 2, 2, 2]"], + ) + def test_bincount_weights(self, array, weights): + np_a = numpy.array(array) + np_weights = numpy.array(weights) if weights is not None else weights + dpnp_a = dpnp.array(array) + dpnp_weights = dpnp.array(weights) if weights is not None else weights + + expected = numpy.bincount(np_a, weights=np_weights) + result = dpnp.bincount(dpnp_a, weights=dpnp_weights) + assert_allclose(expected, result) diff --git a/tests/test_statistics.py b/tests/test_statistics.py index b0567216825a..4b67a8c84fa6 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -628,41 +628,6 @@ def test_corrcoef_scalar(self): assert_dtype_allclose(result, expected) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -class TestBincount: - @pytest.mark.parametrize( - "array", - [[1, 2, 3], [1, 2, 2, 1, 2, 4], [2, 2, 2, 2]], - ids=["[1, 2, 3]", "[1, 2, 2, 1, 2, 4]", "[2, 2, 2, 2]"], - ) - @pytest.mark.parametrize( - "minlength", [0, 1, 3, 5], ids=["0", "1", "3", "5"] - ) - def test_bincount_minlength(self, array, minlength): - np_a = numpy.array(array) - dpnp_a = dpnp.array(array) - - expected = numpy.bincount(np_a, 
minlength=minlength) - result = dpnp.bincount(dpnp_a, minlength=minlength) - assert_allclose(expected, result) - - @pytest.mark.parametrize( - "array", [[1, 2, 2, 1, 2, 4]], ids=["[1, 2, 2, 1, 2, 4]"] - ) - @pytest.mark.parametrize( - "weights", - [None, [0.3, 0.5, 0.2, 0.7, 1.0, -0.6], [2, 2, 2, 2, 2, 2]], - ids=["None", "[0.3, 0.5, 0.2, 0.7, 1., -0.6]", "[2, 2, 2, 2, 2, 2]"], - ) - def test_bincount_weights(self, array, weights): - np_a = numpy.array(array) - dpnp_a = dpnp.array(array) - - expected = numpy.bincount(np_a, weights=weights) - result = dpnp.bincount(dpnp_a, weights=weights) - assert_allclose(expected, result) - - @pytest.mark.parametrize( "dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True) ) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 604a60446c78..43dda9a3ed50 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -2562,6 +2562,27 @@ def test_lstsq(m, n, nrhs, device): assert_sycl_queue_equal(param_dp.sycl_queue, b_dp.sycl_queue) +@pytest.mark.parametrize("weights", [None, numpy.arange(7, 12)]) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_bincount(weights, device): + v = numpy.arange(5) + w = weights + + iv = dpnp.array(v, device=device) + iw = None if weights is None else dpnp.array(w, sycl_queue=iv.sycl_queue) + + expected_hist = numpy.bincount(v, weights=w) + result_hist = dpnp.bincount(iv, weights=iw) + assert_array_equal(result_hist, expected_hist) + + hist_queue = result_hist.sycl_queue + assert_sycl_queue_equal(hist_queue, iv.sycl_queue) + + @pytest.mark.parametrize("weights", [None, numpy.arange(7, 12)]) @pytest.mark.parametrize( "device", diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 4aa74182ce77..f58c58605de0 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -1584,6 +1584,18 @@ def test_histogram(usm_type_v, usm_type_w): assert edges.usm_type == 
du.get_coerced_usm_type([usm_type_v, usm_type_w]) +@pytest.mark.parametrize("usm_type_v", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize("usm_type_w", list_of_usm_types, ids=list_of_usm_types) +def test_bincount(usm_type_v, usm_type_w): + v = dp.arange(5, usm_type=usm_type_v) + w = dp.arange(7, 12, usm_type=usm_type_w) + + hist = dp.bincount(v, weights=w) + assert v.usm_type == usm_type_v + assert w.usm_type == usm_type_w + assert hist.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w]) + + @pytest.mark.parametrize( "func", ["tril_indices_from", "triu_indices_from", "diag_indices_from"] ) diff --git a/tests/third_party/cupy/statistics_tests/test_histogram.py b/tests/third_party/cupy/statistics_tests/test_histogram.py index 521bd4062fb3..29c02c2ea5d8 100644 --- a/tests/third_party/cupy/statistics_tests/test_histogram.py +++ b/tests/third_party/cupy/statistics_tests/test_histogram.py @@ -258,36 +258,33 @@ def test_histogram_bins_not_ordered(self, dtype): with pytest.raises(ValueError): xp.histogram(x, bins) - @pytest.mark.skip("bincount() is not implemented yet") @for_all_dtypes_bincount() @testing.numpy_cupy_allclose(accept_error=TypeError) def test_bincount(self, xp, dtype): x = testing.shaped_arange((3,), xp, dtype) return xp.bincount(x) - @pytest.mark.skip("bincount() is not implemented yet") @for_all_dtypes_bincount() @testing.numpy_cupy_allclose(accept_error=TypeError) def test_bincount_duplicated_value(self, xp, dtype): x = xp.array([1, 2, 2, 1, 2, 4], dtype) return xp.bincount(x) - @pytest.mark.skip("bincount() is not implemented yet") @for_all_dtypes_combination_bincount(names=["x_type", "w_type"]) - @testing.numpy_cupy_allclose(accept_error=TypeError) + @testing.numpy_cupy_allclose( + accept_error=TypeError, type_check=has_support_aspect64() + ) def test_bincount_with_weight(self, xp, x_type, w_type): x = testing.shaped_arange((3,), xp, x_type) w = testing.shaped_arange((3,), xp, w_type) return xp.bincount(x, weights=w) - 
@pytest.mark.skip("bincount() is not implemented yet") @for_all_dtypes_bincount() @testing.numpy_cupy_allclose(accept_error=TypeError) def test_bincount_with_minlength(self, xp, dtype): x = testing.shaped_arange((3,), xp, dtype) return xp.bincount(x, minlength=5) - @pytest.mark.skip("bincount() is not implemented yet") @for_all_dtypes_combination_bincount(names=["x_type", "w_type"]) def test_bincount_invalid_weight_length(self, x_type, w_type): for xp in (numpy, cupy): @@ -298,7 +295,6 @@ def test_bincount_invalid_weight_length(self, x_type, w_type): with pytest.raises((ValueError, TypeError)): xp.bincount(x, weights=w) - @pytest.mark.skip("bincount() is not implemented yet") @for_signed_dtypes_bincount() def test_bincount_negative(self, dtype): for xp in (numpy, cupy): @@ -306,7 +302,6 @@ def test_bincount_negative(self, dtype): with pytest.raises(ValueError): xp.bincount(x) - @pytest.mark.skip("bincount() is not implemented yet") @for_all_dtypes_bincount() def test_bincount_too_deep(self, dtype): for xp in (numpy, cupy): @@ -314,22 +309,19 @@ def test_bincount_too_deep(self, dtype): with pytest.raises(ValueError): xp.bincount(x) - @pytest.mark.skip("bincount() is not implemented yet") @for_all_dtypes_bincount() def test_bincount_too_small(self, dtype): for xp in (numpy, cupy): - x = xp.zeros((), dtype) + x = xp.zeros((), dtype=dtype) with pytest.raises(ValueError): xp.bincount(x) - @pytest.mark.skip("bincount() is not implemented yet") @for_all_dtypes_bincount() @testing.numpy_cupy_allclose(accept_error=TypeError) def test_bincount_zero(self, xp, dtype): x = testing.shaped_arange((3,), xp, dtype) return xp.bincount(x, minlength=0) - @pytest.mark.skip("bincount() is not implemented yet") @for_all_dtypes_bincount() def test_bincount_too_small_minlength(self, dtype): for xp in (numpy, cupy): diff --git a/tests_external/skipped_tests_numpy.tbl b/tests_external/skipped_tests_numpy.tbl index c2c0dc78ec54..a1ca87fb4bdb 100644 --- a/tests_external/skipped_tests_numpy.tbl 
+++ b/tests_external/skipped_tests_numpy.tbl @@ -322,7 +322,6 @@ tests/test_deprecations.py::TestAlen::test_alen tests/test_deprecations.py::TestArrayDataAttributeAssignmentDeprecation::test_data_attr_assignment tests/test_deprecations.py::TestBinaryReprInsufficientWidthParameterForRepresentation::test_insufficient_width_negative tests/test_deprecations.py::TestBinaryReprInsufficientWidthParameterForRepresentation::test_insufficient_width_positive -tests/test_deprecations.py::TestBincount::test_bincount_minlength tests/test_deprecations.py::TestComparisonDeprecations::test_array_richcompare_legacy_weirdness tests/test_deprecations.py::TestComparisonDeprecations::test_normal_types tests/test_deprecations.py::TestComparisonDeprecations::test_string