Skip to content

Commit 5274a5e

Browse files
Merge master into move_tests_folder
2 parents f5a6031 + 7bfe0c8 commit 5274a5e

File tree

20 files changed

+911
-213
lines changed

20 files changed

+911
-213
lines changed

.github/workflows/conda-package.yml

Lines changed: 2 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,6 @@ jobs:
7878
- name: Install conda-build
7979
run: mamba install conda-build=${{ env.CONDA_BUILD_VERSION}}
8080

81-
- name: Cache conda packages
82-
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
83-
env:
84-
CACHE_NUMBER: 1 # Increase to reset cache
85-
with:
86-
path: ${{ runner.os == 'Linux' && '/home/runner/conda_pkgs_dir' || 'C:\Users\runneradmin\conda_pkgs_dir' }}
87-
key:
88-
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-python-${{ matrix.python }}-${{hashFiles('**/meta.yaml') }}
89-
restore-keys: |
90-
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-python-${{ matrix.python }}-
91-
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-
92-
9381
- name: Build conda package
9482
run: conda build --no-test --python ${{ matrix.python }} --numpy 2.0 ${{ env.CHANNELS }} conda-recipe
9583
env:
@@ -125,7 +113,6 @@ jobs:
125113
continue-on-error: true
126114

127115
env:
128-
conda-pkgs: '/home/runner/conda_pkgs_dir/'
129116
channel-path: '${{ github.workspace }}/channel/'
130117
pkg-path-in-channel: '${{ github.workspace }}/channel/linux-64/'
131118
extracted-pkg-path: '${{ github.workspace }}/pkg/'
@@ -165,30 +152,13 @@ jobs:
165152
mamba search ${{ env.PACKAGE_NAME }} -c ${{ env.channel-path }} --override-channels --info --json > ${{ env.ver-json-path }}
166153
cat ${{ env.ver-json-path }}
167154
168-
- name: Collect dependencies
155+
- name: Get package version
169156
run: |
170157
export PACKAGE_VERSION=$(python -c "${{ env.VER_SCRIPT1 }} ${{ env.VER_SCRIPT2 }}")
171158
172159
echo PACKAGE_VERSION=${PACKAGE_VERSION}
173160
echo "PACKAGE_VERSION=$PACKAGE_VERSION" >> $GITHUB_ENV
174161
175-
mamba install ${{ env.PACKAGE_NAME }}=${PACKAGE_VERSION} python=${{ matrix.python }} ${{ env.TEST_CHANNELS }} --only-deps --dry-run > lockfile
176-
cat lockfile
177-
env:
178-
TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'
179-
180-
- name: Cache conda packages
181-
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
182-
env:
183-
CACHE_NUMBER: 1 # Increase to reset cache
184-
with:
185-
path: ${{ env.conda-pkgs }}
186-
key:
187-
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-python-${{ matrix.python }}-${{hashFiles('lockfile') }}
188-
restore-keys: |
189-
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-python-${{ matrix.python }}-
190-
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-
191-
192162
- name: Install dpnp
193163
run: mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} pytest python=${{ matrix.python }} ${{ env.TEST_CHANNELS }}
194164
env:
@@ -242,7 +212,6 @@ jobs:
242212
continue-on-error: true
243213

244214
env:
245-
conda-pkgs: 'C:\Users\runneradmin\conda_pkgs_dir\'
246215
channel-path: '${{ github.workspace }}\channel\'
247216
pkg-path-in-channel: '${{ github.workspace }}\channel\win-64\'
248217
extracted-pkg-path: '${{ github.workspace }}\pkg'
@@ -302,7 +271,7 @@ jobs:
302271
- name: Dump version.json
303272
run: more ${{ env.ver-json-path }}
304273

305-
- name: Collect dependencies
274+
- name: Get package version
306275
run: |
307276
@echo on
308277
set "SCRIPT=${{ env.VER_SCRIPT1 }} ${{ env.VER_SCRIPT2 }}"
@@ -312,25 +281,6 @@ jobs:
312281
echo PACKAGE_VERSION: %PACKAGE_VERSION%
313282
(echo PACKAGE_VERSION=%PACKAGE_VERSION%) >> %GITHUB_ENV%
314283
315-
mamba install ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% python=${{ matrix.python }} ${{ env.TEST_CHANNELS }} --only-deps --dry-run > lockfile
316-
env:
317-
TEST_CHANNELS: '-c ${{ env.channel-path }} ${{ env.CHANNELS }}'
318-
319-
- name: Dump lockfile
320-
run: more lockfile
321-
322-
- name: Cache conda packages
323-
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
324-
env:
325-
CACHE_NUMBER: 1 # Increase to reset cache
326-
with:
327-
path: ${{ env.conda-pkgs }}
328-
key:
329-
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-python-${{ matrix.python }}-${{hashFiles('lockfile') }}
330-
restore-keys: |
331-
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-python-${{ matrix.python }}-
332-
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-
333-
334284
- name: Install dpnp
335285
run: |
336286
@echo on

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
set(python_module_name _statistics_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/bincount.cpp
3031
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <memory>
27+
28+
#include <pybind11/pybind11.h>
29+
#include <pybind11/stl.h>
30+
31+
#include "bincount.hpp"
32+
#include "histogram_common.hpp"
33+
34+
using dpctl::tensor::usm_ndarray;
35+
36+
using namespace statistics::histogram;
37+
using namespace statistics::common;
38+
39+
namespace
40+
{
41+
42+
template <typename T>
43+
struct BincountEdges
44+
{
45+
static constexpr bool const sync_after_init = false;
46+
using boundsT = std::tuple<T, T>;
47+
48+
BincountEdges(const T &min, const T &max)
49+
{
50+
this->min = min;
51+
this->max = max;
52+
}
53+
54+
template <int _Dims>
55+
void init(const sycl::nd_item<_Dims> &) const
56+
{
57+
}
58+
59+
boundsT get_bounds() const
60+
{
61+
return {min, max};
62+
}
63+
64+
template <int _Dims, typename dT>
65+
size_t get_bin(const sycl::nd_item<_Dims> &,
66+
const dT *val,
67+
const boundsT &) const
68+
{
69+
return val[0] - min;
70+
}
71+
72+
template <typename dT>
73+
bool in_bounds(const dT *val, const boundsT &bounds) const
74+
{
75+
return check_in_bounds(val[0], std::get<0>(bounds),
76+
std::get<1>(bounds));
77+
}
78+
79+
private:
80+
T min;
81+
T max;
82+
};
83+
84+
template <typename T, typename HistType = size_t>
85+
struct BincountF
86+
{
87+
static sycl::event impl(sycl::queue &exec_q,
88+
const void *vin,
89+
const int64_t min,
90+
const int64_t max,
91+
const void *vweights,
92+
void *vout,
93+
const size_t,
94+
const size_t size,
95+
const std::vector<sycl::event> &depends)
96+
{
97+
const T *in = static_cast<const T *>(vin);
98+
const HistType *weights = static_cast<const HistType *>(vweights);
99+
// shift output pointer by min elements
100+
HistType *out = static_cast<HistType *>(vout) + min;
101+
102+
const size_t needed_bins_count = (max - min) + 1;
103+
104+
const uint32_t local_size = get_max_local_size(exec_q);
105+
106+
constexpr uint32_t WorkPI = 128; // empirically found number
107+
const auto nd_range = make_ndrange(size, local_size, WorkPI);
108+
109+
return exec_q.submit([&](sycl::handler &cgh) {
110+
cgh.depends_on(depends);
111+
constexpr uint32_t dims = 1;
112+
113+
auto dispatch_bins = [&](const auto &weights) {
114+
const auto local_mem_size =
115+
get_local_mem_size_in_items<T>(exec_q);
116+
if (local_mem_size >= needed_bins_count) {
117+
const uint32_t local_hist_count =
118+
get_local_hist_copies_count(local_mem_size, local_size,
119+
needed_bins_count);
120+
121+
auto hist = HistWithLocalCopies<HistType>(
122+
out, needed_bins_count, local_hist_count, cgh);
123+
124+
auto edges = BincountEdges(min, max);
125+
submit_histogram(in, size, dims, WorkPI, hist, edges,
126+
weights, nd_range, cgh);
127+
}
128+
else {
129+
auto hist = HistGlobalMemory<HistType>(out);
130+
auto edges = BincountEdges(min, max);
131+
submit_histogram(in, size, dims, WorkPI, hist, edges,
132+
weights, nd_range, cgh);
133+
}
134+
};
135+
136+
if (weights) {
137+
auto _weights = Weights(weights);
138+
dispatch_bins(_weights);
139+
}
140+
else {
141+
auto _weights = NoWeights();
142+
dispatch_bins(_weights);
143+
}
144+
});
145+
}
146+
};
147+
148+
// One (sample type, histogram type) pair; the sample type is always int64_t.
template <typename HistT>
using Int64SampleWith = std::tuple<int64_t, HistT>;

// Type pairs handed to the dispatch table together with BincountF.
using SupportedTypes = std::tuple<Int64SampleWith<int64_t>,
                                  Int64SampleWith<float>,
                                  Int64SampleWith<double>>;
151+
152+
} // namespace
153+
154+
// Creates the dispatch table (labelled "sample" / "histogram") and fills it
// with a BincountF entry for every type pair listed in SupportedTypes.
Bincount::Bincount() : dispatch_table("sample", "histogram")
{
    this->dispatch_table.populate_dispatch_table<SupportedTypes, BincountF>();
}
158+
159+
std::tuple<sycl::event, sycl::event> Bincount::call(
160+
const dpctl::tensor::usm_ndarray &sample,
161+
const int64_t min,
162+
const int64_t max,
163+
const std::optional<const dpctl::tensor::usm_ndarray> &weights,
164+
dpctl::tensor::usm_ndarray &histogram,
165+
const std::vector<sycl::event> &depends)
166+
{
167+
validate(sample, std::optional<const dpctl::tensor::usm_ndarray>(), weights,
168+
histogram);
169+
170+
if (sample.get_size() == 0) {
171+
return {sycl::event(), sycl::event()};
172+
}
173+
174+
const int sample_typenum = sample.get_typenum();
175+
const int hist_typenum = histogram.get_typenum();
176+
177+
auto bincount_func = dispatch_table.get(sample_typenum, hist_typenum);
178+
179+
auto exec_q = sample.get_queue();
180+
181+
void *weights_ptr =
182+
weights.has_value() ? weights.value().get_data() : nullptr;
183+
184+
auto ev = bincount_func(exec_q, sample.get_data(), min, max, weights_ptr,
185+
histogram.get_data(), histogram.get_shape(0),
186+
sample.get_shape(0), depends);
187+
188+
sycl::event args_ev;
189+
if (weights.has_value()) {
190+
args_ev = dpctl::utils::keep_args_alive(
191+
exec_q, {sample, weights.value(), histogram}, {ev});
192+
}
193+
else {
194+
args_ev =
195+
dpctl::utils::keep_args_alive(exec_q, {sample, histogram}, {ev});
196+
}
197+
198+
return {args_ev, ev};
199+
}
200+
201+
std::unique_ptr<Bincount> bincount;
202+
203+
void statistics::histogram::populate_bincount(py::module_ m)
204+
{
205+
using namespace std::placeholders;
206+
207+
bincount.reset(new Bincount());
208+
209+
auto bincount_func =
210+
[bincountp = bincount.get()](
211+
const dpctl::tensor::usm_ndarray &sample, int64_t min, int64_t max,
212+
std::optional<const dpctl::tensor::usm_ndarray> &weights,
213+
dpctl::tensor::usm_ndarray &histogram,
214+
const std::vector<sycl::event> &depends) {
215+
return bincountp->call(sample, min, max, weights, histogram,
216+
depends);
217+
};
218+
219+
m.def("bincount", bincount_func,
220+
"Count number of occurrences of each value in array of non-negative "
221+
"ints.",
222+
py::arg("sample"), py::arg("min"), py::arg("max"), py::arg("weights"),
223+
py::arg("histogram"), py::arg("depends") = py::list());
224+
225+
auto bincount_dtypes = [bincountp = bincount.get()]() {
226+
return bincountp->dispatch_table.get_all_supported_types();
227+
};
228+
229+
m.def("bincount_dtypes", bincount_dtypes,
230+
"Get the supported data types for bincount.");
231+
}

0 commit comments

Comments
 (0)