Skip to content

Commit 2c7c177

Browse files
Implementation of bincounts
1 parent 81047fa commit 2c7c177

File tree

14 files changed

+622
-92
lines changed

14 files changed

+622
-92
lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
set(python_module_name _statistics_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
3031
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <memory>
27+
28+
#include <pybind11/pybind11.h>
29+
#include <pybind11/stl.h>
30+
31+
#include "bincount.hpp"
32+
#include "histogram_common.hpp"
33+
34+
using dpctl::tensor::usm_ndarray;
35+
36+
using namespace statistics::histogram;
37+
using namespace statistics::common;
38+
39+
namespace
40+
{
41+
42+
template <typename T>
43+
struct BincountEdges
44+
{
45+
static constexpr bool const sync_after_init = false;
46+
using boundsT = std::tuple<T, T>;
47+
48+
BincountEdges(const T &min, const T &max)
49+
{
50+
this->min = min;
51+
this->max = max;
52+
}
53+
54+
template <int _Dims>
55+
void init(const sycl::nd_item<_Dims> &) const
56+
{
57+
}
58+
59+
boundsT get_bounds() const
60+
{
61+
return {min, max};
62+
}
63+
64+
template <int _Dims, typename dT>
65+
size_t get_bin(const sycl::nd_item<_Dims> &,
66+
const dT *val,
67+
const boundsT &) const
68+
{
69+
return val[0] - min;
70+
}
71+
72+
template <typename dT>
73+
bool in_bounds(const dT *val, const boundsT &bounds) const
74+
{
75+
Less<dT> _less;
76+
return !_less(val[0], std::get<0>(bounds)) &&
77+
!_less(std::get<1>(bounds), val[0]) && !IsNan<dT>::isnan(val[0]);
78+
}
79+
80+
private:
81+
T min;
82+
T max;
83+
};
84+
85+
template <typename T, typename HistType = size_t>
86+
struct BincountF
87+
{
88+
static sycl::event impl(sycl::queue &exec_q,
89+
const void *vin,
90+
int64_t min,
91+
int64_t max,
92+
const void *vweights,
93+
void *vout,
94+
const size_t,
95+
const size_t size,
96+
const std::vector<sycl::event> &depends)
97+
{
98+
const T *in = static_cast<const T *>(vin);
99+
const HistType *weights = static_cast<const HistType *>(vweights);
100+
// shift output pointer by min elements
101+
HistType *out = static_cast<HistType *>(vout) + min;
102+
103+
size_t needed_bins_count = (max - min) + 1;
104+
auto device = exec_q.get_device();
105+
106+
uint32_t local_size = get_max_local_size(exec_q);
107+
108+
uint32_t WorkPI = 128; // empirically found number
109+
auto nd_range = make_ndrange(size, local_size, WorkPI);
110+
111+
return exec_q.submit([&](sycl::handler &cgh) {
112+
cgh.depends_on(depends);
113+
constexpr uint32_t dims = 1;
114+
115+
auto dispatch_bins = [&](const auto &weights) {
116+
auto local_mem_size =
117+
device.get_info<sycl::info::device::local_mem_size>() /
118+
sizeof(T);
119+
if (local_mem_size >= needed_bins_count) {
120+
uint32_t local_hist_count = get_local_hist_copies_count(
121+
local_mem_size, local_size, needed_bins_count);
122+
123+
auto hist = HistWithLocalCopies<HistType>(
124+
out, needed_bins_count, local_hist_count, cgh);
125+
126+
auto edges = BincountEdges(min, max);
127+
submit_histogram(in, size, dims, WorkPI, hist, edges,
128+
weights, nd_range, cgh);
129+
}
130+
else {
131+
auto hist = HistGlobalMemory<HistType>(out);
132+
auto edges = BincountEdges(min, max);
133+
submit_histogram(in, size, dims, WorkPI, hist, edges,
134+
weights, nd_range, cgh);
135+
}
136+
};
137+
138+
if (weights) {
139+
auto _weights = Weights(weights);
140+
dispatch_bins(_weights);
141+
}
142+
else {
143+
auto _weights = NoWeights();
144+
dispatch_bins(_weights);
145+
}
146+
});
147+
}
148+
};
149+
150+
// Type pairs passed to populate_dispatch_table together with BincountF in the
// Bincount constructor. Each pair supplies BincountF's two template arguments
// (T, HistType): int64 samples with an int64, float or double histogram.
using SupportedTypes = std::tuple<std::tuple<int64_t, int64_t>,
151+
std::tuple<int64_t, float>,
152+
std::tuple<int64_t, double>>;
153+
154+
} // namespace
155+
156+
// Constructs the dispatch table and fills it with BincountF instantiations
// for every type pair listed in SupportedTypes.
// NOTE(review): "sample" and "histogram" are presumably the names of the two
// dispatch axes used in diagnostics - confirm in dispatch_table.hpp.
Bincount::Bincount() : dispatch_table("sample", "histogram")
157+
{
158+
dispatch_table.populate_dispatch_table<SupportedTypes, BincountF>();
159+
}
160+
161+
std::tuple<sycl::event, sycl::event>
162+
Bincount::call(const dpctl::tensor::usm_ndarray &sample,
163+
int64_t min,
164+
int64_t max,
165+
std::optional<const dpctl::tensor::usm_ndarray> &weights,
166+
dpctl::tensor::usm_ndarray &histogram,
167+
const std::vector<sycl::event> &depends)
168+
{
169+
validate(sample, std::optional<const dpctl::tensor::usm_ndarray>(), weights,
170+
histogram);
171+
172+
if (sample.get_size() == 0) {
173+
return {sycl::event(), sycl::event()};
174+
}
175+
176+
const int sample_typenum = sample.get_typenum();
177+
const int hist_typenum = histogram.get_typenum();
178+
179+
auto bincount_func = dispatch_table.get(sample_typenum, hist_typenum);
180+
181+
auto exec_q = sample.get_queue();
182+
183+
void *weights_ptr =
184+
weights.has_value() ? weights.value().get_data() : nullptr;
185+
186+
auto ev = bincount_func(exec_q, sample.get_data(), min, max, weights_ptr,
187+
histogram.get_data(), histogram.get_shape(0),
188+
sample.get_shape(0), depends);
189+
190+
sycl::event args_ev;
191+
if (weights.has_value()) {
192+
args_ev = dpctl::utils::keep_args_alive(
193+
exec_q, {sample, weights.value(), histogram}, {ev});
194+
}
195+
else {
196+
args_ev =
197+
dpctl::utils::keep_args_alive(exec_q, {sample, histogram}, {ev});
198+
}
199+
200+
return {args_ev, ev};
201+
}
202+
203+
std::unique_ptr<Bincount> bincount;
204+
205+
void statistics::histogram::populate_bincount(py::module_ m)
206+
{
207+
using namespace std::placeholders;
208+
209+
bincount.reset(new Bincount());
210+
211+
auto bincount_func =
212+
[bincountp = bincount.get()](
213+
const dpctl::tensor::usm_ndarray &sample, int64_t min, int64_t max,
214+
std::optional<const dpctl::tensor::usm_ndarray> &weights,
215+
dpctl::tensor::usm_ndarray &histogram,
216+
const std::vector<sycl::event> &depends) {
217+
return bincountp->call(sample, min, max, weights, histogram,
218+
depends);
219+
};
220+
221+
m.def("bincount", bincount_func,
222+
"Count number of occurrences of each value in array of non-negative "
223+
"ints.",
224+
py::arg("sample"), py::arg("min"), py::arg("max"), py::arg("weights"),
225+
py::arg("histogram"), py::arg("depends") = py::list());
226+
227+
auto bincount_dtypes = [bincountp = bincount.get()]() {
228+
return bincountp->dispatch_table.get_all_supported_types();
229+
};
230+
231+
m.def("bincount_dtypes", bincount_dtypes,
232+
"Get the supported data types for bincount.");
233+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <dpctl4pybind11.hpp>
29+
#include <sycl/sycl.hpp>
30+
31+
#include "dispatch_table.hpp"
32+
33+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
34+
35+
namespace statistics
36+
{
37+
namespace histogram
38+
{
39+
// Host-side entry point for bincount: owns the dispatch table of kernel
// launchers and exposes call(), which runs the matching launcher.
struct Bincount
40+
{
41+
// Signature of a kernel launcher stored in the dispatch table. As used by
// bincount.cpp, the arguments are, in order: queue, sample data, min, max,
// weights data (or nullptr), histogram data, histogram length, sample
// length, and the events to depend on.
using FnT = sycl::event (*)(sycl::queue &,
42+
const void *,
43+
int64_t,
44+
int64_t,
45+
const void *,
46+
void *,
47+
const size_t,
48+
const size_t,
49+
const std::vector<sycl::event> &);
50+
51+
// Launcher lookup; call() queries it with the sample and histogram typenums.
common::DispatchTable2<FnT> dispatch_table;
52+
53+
// Populates dispatch_table (defined in bincount.cpp).
Bincount();
54+
55+
// Validates the arguments and launches the kernel selected from
// dispatch_table. Returns a pair of events: (keep-alive event, kernel event).
std::tuple<sycl::event, sycl::event>
56+
call(const dpctl::tensor::usm_ndarray &input,
57+
int64_t min,
58+
int64_t max,
59+
std::optional<const dpctl::tensor::usm_ndarray> &weights,
60+
dpctl::tensor::usm_ndarray &output,
61+
const std::vector<sycl::event> &depends);
62+
};
63+
64+
// Registers the `bincount` and `bincount_dtypes` functions on module `m`
// (defined in bincount.cpp).
void populate_bincount(py::module_ m);
65+
} // namespace histogram
66+
} // namespace statistics

0 commit comments

Comments
 (0)