Skip to content

Commit d536dae

Browse files
Implementation of bincounts (#2142)
Co-authored-by: Anton <[email protected]>
1 parent 923eb84 commit d536dae

16 files changed

+706
-158
lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
set(python_module_name _statistics_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
3031
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <memory>
27+
28+
#include <pybind11/pybind11.h>
29+
#include <pybind11/stl.h>
30+
31+
#include "bincount.hpp"
32+
#include "histogram_common.hpp"
33+
34+
using dpctl::tensor::usm_ndarray;
35+
36+
using namespace statistics::histogram;
37+
using namespace statistics::common;
38+
39+
namespace
40+
{
41+
42+
template <typename T>
43+
struct BincountEdges
44+
{
45+
static constexpr bool const sync_after_init = false;
46+
using boundsT = std::tuple<T, T>;
47+
48+
BincountEdges(const T &min, const T &max)
49+
{
50+
this->min = min;
51+
this->max = max;
52+
}
53+
54+
template <int _Dims>
55+
void init(const sycl::nd_item<_Dims> &) const
56+
{
57+
}
58+
59+
boundsT get_bounds() const
60+
{
61+
return {min, max};
62+
}
63+
64+
template <int _Dims, typename dT>
65+
size_t get_bin(const sycl::nd_item<_Dims> &,
66+
const dT *val,
67+
const boundsT &) const
68+
{
69+
return val[0] - min;
70+
}
71+
72+
template <typename dT>
73+
bool in_bounds(const dT *val, const boundsT &bounds) const
74+
{
75+
return check_in_bounds(val[0], std::get<0>(bounds),
76+
std::get<1>(bounds));
77+
}
78+
79+
private:
80+
T min;
81+
T max;
82+
};
83+
84+
template <typename T, typename HistType = size_t>
85+
struct BincountF
86+
{
87+
static sycl::event impl(sycl::queue &exec_q,
88+
const void *vin,
89+
const int64_t min,
90+
const int64_t max,
91+
const void *vweights,
92+
void *vout,
93+
const size_t,
94+
const size_t size,
95+
const std::vector<sycl::event> &depends)
96+
{
97+
const T *in = static_cast<const T *>(vin);
98+
const HistType *weights = static_cast<const HistType *>(vweights);
99+
// shift output pointer by min elements
100+
HistType *out = static_cast<HistType *>(vout) + min;
101+
102+
const size_t needed_bins_count = (max - min) + 1;
103+
104+
const uint32_t local_size = get_max_local_size(exec_q);
105+
106+
constexpr uint32_t WorkPI = 128; // empirically found number
107+
const auto nd_range = make_ndrange(size, local_size, WorkPI);
108+
109+
return exec_q.submit([&](sycl::handler &cgh) {
110+
cgh.depends_on(depends);
111+
constexpr uint32_t dims = 1;
112+
113+
auto dispatch_bins = [&](const auto &weights) {
114+
const auto local_mem_size =
115+
get_local_mem_size_in_items<T>(exec_q);
116+
if (local_mem_size >= needed_bins_count) {
117+
const uint32_t local_hist_count =
118+
get_local_hist_copies_count(local_mem_size, local_size,
119+
needed_bins_count);
120+
121+
auto hist = HistWithLocalCopies<HistType>(
122+
out, needed_bins_count, local_hist_count, cgh);
123+
124+
auto edges = BincountEdges(min, max);
125+
submit_histogram(in, size, dims, WorkPI, hist, edges,
126+
weights, nd_range, cgh);
127+
}
128+
else {
129+
auto hist = HistGlobalMemory<HistType>(out);
130+
auto edges = BincountEdges(min, max);
131+
submit_histogram(in, size, dims, WorkPI, hist, edges,
132+
weights, nd_range, cgh);
133+
}
134+
};
135+
136+
if (weights) {
137+
auto _weights = Weights(weights);
138+
dispatch_bins(_weights);
139+
}
140+
else {
141+
auto _weights = NoWeights();
142+
dispatch_bins(_weights);
143+
}
144+
});
145+
}
146+
};
147+
148+
// One (sample dtype, histogram dtype) entry; samples are always int64_t.
template <typename HistT>
using Int64SampleWith = std::tuple<int64_t, HistT>;

// Type pairs handed to populate_dispatch_table together with BincountF.
using SupportedTypes = std::tuple<Int64SampleWith<int64_t>,
                                  Int64SampleWith<float>,
                                  Int64SampleWith<double>>;
151+
152+
} // namespace
153+
154+
// Builds the dispatch table (axes labelled "sample" and "histogram") and
// fills it from SupportedTypes, using BincountF as the implementation.
Bincount::Bincount() : dispatch_table("sample", "histogram")
{
    dispatch_table.populate_dispatch_table<SupportedTypes, BincountF>();
}
158+
159+
std::tuple<sycl::event, sycl::event> Bincount::call(
160+
const dpctl::tensor::usm_ndarray &sample,
161+
const int64_t min,
162+
const int64_t max,
163+
const std::optional<const dpctl::tensor::usm_ndarray> &weights,
164+
dpctl::tensor::usm_ndarray &histogram,
165+
const std::vector<sycl::event> &depends)
166+
{
167+
validate(sample, std::optional<const dpctl::tensor::usm_ndarray>(), weights,
168+
histogram);
169+
170+
if (sample.get_size() == 0) {
171+
return {sycl::event(), sycl::event()};
172+
}
173+
174+
const int sample_typenum = sample.get_typenum();
175+
const int hist_typenum = histogram.get_typenum();
176+
177+
auto bincount_func = dispatch_table.get(sample_typenum, hist_typenum);
178+
179+
auto exec_q = sample.get_queue();
180+
181+
void *weights_ptr =
182+
weights.has_value() ? weights.value().get_data() : nullptr;
183+
184+
auto ev = bincount_func(exec_q, sample.get_data(), min, max, weights_ptr,
185+
histogram.get_data(), histogram.get_shape(0),
186+
sample.get_shape(0), depends);
187+
188+
sycl::event args_ev;
189+
if (weights.has_value()) {
190+
args_ev = dpctl::utils::keep_args_alive(
191+
exec_q, {sample, weights.value(), histogram}, {ev});
192+
}
193+
else {
194+
args_ev =
195+
dpctl::utils::keep_args_alive(exec_q, {sample, histogram}, {ev});
196+
}
197+
198+
return {args_ev, ev};
199+
}
200+
201+
std::unique_ptr<Bincount> bincount;
202+
203+
void statistics::histogram::populate_bincount(py::module_ m)
204+
{
205+
using namespace std::placeholders;
206+
207+
bincount.reset(new Bincount());
208+
209+
auto bincount_func =
210+
[bincountp = bincount.get()](
211+
const dpctl::tensor::usm_ndarray &sample, int64_t min, int64_t max,
212+
std::optional<const dpctl::tensor::usm_ndarray> &weights,
213+
dpctl::tensor::usm_ndarray &histogram,
214+
const std::vector<sycl::event> &depends) {
215+
return bincountp->call(sample, min, max, weights, histogram,
216+
depends);
217+
};
218+
219+
m.def("bincount", bincount_func,
220+
"Count number of occurrences of each value in array of non-negative "
221+
"ints.",
222+
py::arg("sample"), py::arg("min"), py::arg("max"), py::arg("weights"),
223+
py::arg("histogram"), py::arg("depends") = py::list());
224+
225+
auto bincount_dtypes = [bincountp = bincount.get()]() {
226+
return bincountp->dispatch_table.get_all_supported_types();
227+
};
228+
229+
m.def("bincount_dtypes", bincount_dtypes,
230+
"Get the supported data types for bincount.");
231+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <dpctl4pybind11.hpp>
29+
#include <sycl/sycl.hpp>
30+
31+
#include "dispatch_table.hpp"
32+
33+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
34+
35+
namespace statistics
36+
{
37+
namespace histogram
38+
{
39+
struct Bincount
40+
{
41+
using FnT = sycl::event (*)(sycl::queue &,
42+
const void *,
43+
const int64_t,
44+
const int64_t,
45+
const void *,
46+
void *,
47+
const size_t,
48+
const size_t,
49+
const std::vector<sycl::event> &);
50+
51+
common::DispatchTable2<FnT> dispatch_table;
52+
53+
Bincount();
54+
55+
std::tuple<sycl::event, sycl::event>
56+
call(const dpctl::tensor::usm_ndarray &input,
57+
const int64_t min,
58+
const int64_t max,
59+
const std::optional<const dpctl::tensor::usm_ndarray> &weights,
60+
dpctl::tensor::usm_ndarray &output,
61+
const std::vector<sycl::event> &depends);
62+
};
63+
64+
void populate_bincount(py::module_ m);
65+
} // namespace histogram
66+
} // namespace statistics

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,19 @@ size_t get_local_mem_size_in_items(const sycl::device &device, size_t reserve)
165165
return get_local_mem_size_in_bytes(device, sizeof(T) * reserve) / sizeof(T);
166166
}
167167

168+
template <typename T>
169+
inline size_t get_local_mem_size_in_items(const sycl::queue &queue)
170+
{
171+
return get_local_mem_size_in_items<T>(queue.get_device());
172+
}
173+
174+
template <typename T>
175+
inline size_t get_local_mem_size_in_items(const sycl::queue &queue,
176+
size_t reserve)
177+
{
178+
return get_local_mem_size_in_items<T>(queue.get_device(), reserve);
179+
}
180+
168181
template <int Dims>
169182
sycl::nd_range<Dims> make_ndrange(const sycl::range<Dims> &global_range,
170183
const sycl::range<Dims> &local_range,

dpnp/backend/extensions/statistics/histogram.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,8 @@ struct HistogramEdges
9494
template <typename dT>
9595
bool in_bounds(const dT *val, const boundsT &bounds) const
9696
{
97-
Less<dT> _less;
98-
return !_less(val[0], std::get<0>(bounds)) &&
99-
!_less(std::get<1>(bounds), val[0]) && !IsNan<dT>::isnan(val[0]);
97+
return check_in_bounds(val[0], std::get<0>(bounds),
98+
std::get<1>(bounds));
10099
}
101100

102101
private:
@@ -110,7 +109,7 @@ template <typename T>
110109
using UncachedEdges = HistogramEdges<T, UncachedData<const T, 1>>;
111110

112111
template <typename T, typename BinsT, typename HistType = size_t>
113-
struct histogram_kernel
112+
struct HistogramF
114113
{
115114
static sycl::event impl(sycl::queue &exec_q,
116115
const void *vin,
@@ -185,7 +184,7 @@ struct histogram_kernel
185184
};
186185

187186
template <typename SampleType, typename HistType>
188-
using histogram_kernel_ = histogram_kernel<SampleType, SampleType, HistType>;
187+
using HistogramF_ = HistogramF<SampleType, SampleType, HistType>;
189188

190189
} // namespace
191190

@@ -212,7 +211,7 @@ using SupportedTypes = std::tuple<std::tuple<uint64_t, int64_t>,
212211

213212
Histogram::Histogram() : dispatch_table("sample", "histogram")
214213
{
215-
dispatch_table.populate_dispatch_table<SupportedTypes, histogram_kernel_>();
214+
dispatch_table.populate_dispatch_table<SupportedTypes, HistogramF_>();
216215
}
217216

218217
std::tuple<sycl::event, sycl::event>

0 commit comments

Comments
 (0)