Skip to content

Commit c1fad0b

Browse files
Implementing histogramdd
1 parent 8464d9b · commit c1fad0b

File tree

9 files changed

+864
-89
lines changed

9 files changed

+864
-89
lines changed

dpnp/backend/extensions/statistics/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ set(python_module_name _statistics_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp
3334
)

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 53 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,30 @@
3434
// so sycl.hpp must be included before math_utils.hpp
3535
#include <sycl/sycl.hpp>
3636
#include "utils/math_utils.hpp"
37+
#include "utils/type_utils.hpp"
3738
// clang-format on
3839

40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace type_utils
45+
{
46+
// Upstream to dpctl
47+
template <class T>
48+
struct is_complex<const std::complex<T>> : std::true_type
49+
{
50+
};
51+
52+
template <typename T>
53+
constexpr bool is_complex_v = is_complex<T>::value;
54+
55+
} // namespace type_utils
56+
} // namespace tensor
57+
} // namespace dpctl
58+
59+
namespace type_utils = dpctl::tensor::type_utils;
60+
3961
namespace statistics
4062
{
4163
namespace common
@@ -56,24 +78,20 @@ constexpr auto Align(N n, D d)
5678
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
5779
struct AtomicOp
5880
{
59-
static void add(T &lhs, const T value)
81+
static void add(T &lhs, const T &value)
6082
{
61-
sycl::atomic_ref<T, Order, Scope> lh(lhs);
62-
lh += value;
63-
}
64-
};
83+
if constexpr (type_utils::is_complex_v<T>) {
84+
using vT = typename T::value_type;
85+
vT *_lhs = reinterpret_cast<vT(&)[2]>(lhs);
86+
const vT *_val = reinterpret_cast<const vT(&)[2]>(value);
6587

66-
template <typename T, sycl::memory_order Order, sycl::memory_scope Scope>
67-
struct AtomicOp<std::complex<T>, Order, Scope>
68-
{
69-
static void add(std::complex<T> &lhs, const std::complex<T> value)
70-
{
71-
T *_lhs = reinterpret_cast<T(&)[2]>(lhs);
72-
const T *_val = reinterpret_cast<const T(&)[2]>(value);
73-
sycl::atomic_ref<T, Order, Scope> lh0(_lhs[0]);
74-
lh0 += _val[0];
75-
sycl::atomic_ref<T, Order, Scope> lh1(_lhs[1]);
76-
lh1 += _val[1];
88+
AtomicOp<vT, Order, Scope>::add(_lhs[0], _val[0]);
89+
AtomicOp<vT, Order, Scope>::add(_lhs[1], _val[1]);
90+
}
91+
else {
92+
sycl::atomic_ref<T, Order, Scope> lh(lhs);
93+
lh += value;
94+
}
7795
}
7896
};
7997

@@ -82,17 +100,12 @@ struct Less
82100
{
83101
bool operator()(const T &lhs, const T &rhs) const
84102
{
85-
return std::less{}(lhs, rhs);
86-
}
87-
};
88-
89-
template <typename T>
90-
struct Less<std::complex<T>>
91-
{
92-
bool operator()(const std::complex<T> &lhs,
93-
const std::complex<T> &rhs) const
94-
{
95-
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
103+
if constexpr (type_utils::is_complex_v<T>) {
104+
return dpctl::tensor::math_utils::less_complex(lhs, rhs);
105+
}
106+
else {
107+
return std::less{}(lhs, rhs);
108+
}
96109
}
97110
};
98111

@@ -101,26 +114,25 @@ struct IsNan
101114
{
102115
static bool isnan(const T &v)
103116
{
104-
if constexpr (std::is_floating_point_v<T> ||
105-
std::is_same_v<T, sycl::half>) {
106-
return sycl::isnan(v);
117+
if constexpr (type_utils::is_complex_v<T>) {
118+
const auto real1 = std::real(v);
119+
const auto imag1 = std::imag(v);
120+
121+
using vT = typename T::value_type;
122+
123+
return IsNan<vT>::isnan(real1) || IsNan<vT>::isnan(imag1);
124+
}
125+
else {
126+
if constexpr (std::is_floating_point_v<T> ||
127+
std::is_same_v<T, sycl::half>) {
128+
return sycl::isnan(v);
129+
}
107130
}
108131

109132
return false;
110133
}
111134
};
112135

113-
template <typename T>
114-
struct IsNan<std::complex<T>>
115-
{
116-
static bool isnan(const std::complex<T> &v)
117-
{
118-
T real1 = std::real(v);
119-
T imag1 = std::imag(v);
120-
return sycl::isnan(real1) || sycl::isnan(imag1);
121-
}
122-
};
123-
124136
size_t get_max_local_size(const sycl::device &device);
125137
size_t get_max_local_size(const sycl::device &device,
126138
int cpu_local_size_limit,

dpnp/backend/extensions/statistics/histogram_common.cpp

Lines changed: 22 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -129,22 +129,16 @@ void validate(const usm_ndarray &sample,
129129
" parameter must have at least 1 element");
130130
}
131131

132-
if (histogram.get_ndim() != 1) {
133-
throw py::value_error(get_name(&histogram) +
134-
" parameter must be 1d. Actual " +
135-
std::to_string(histogram.get_ndim()) + "d");
136-
}
137-
138132
if (weights_ptr) {
139133
if (weights_ptr->get_ndim() != 1) {
140134
throw py::value_error(
141135
get_name(weights_ptr) + " parameter must be 1d. Actual " +
142136
std::to_string(weights_ptr->get_ndim()) + "d");
143137
}
144138

145-
auto sample_size = sample.get_size();
139+
auto sample_size = sample.get_shape(0);
146140
auto weights_size = weights_ptr->get_size();
147-
if (sample.get_size() != weights_ptr->get_size()) {
141+
if (sample_size != weights_ptr->get_size()) {
148142
throw py::value_error(
149143
get_name(&sample) + " size (" + std::to_string(sample_size) +
150144
") and " + get_name(weights_ptr) + " size (" +
@@ -160,42 +154,40 @@ void validate(const usm_ndarray &sample,
160154
}
161155

162156
if (sample.get_ndim() == 1) {
163-
if (bins.get_ndim() != 1) {
157+
if (histogram.get_ndim() != 1) {
164158
throw py::value_error(get_name(&sample) + " parameter is 1d, but " +
165-
get_name(&bins) + " is " +
166-
std::to_string(bins.get_ndim()) + "d");
159+
get_name(&histogram) + " is " +
160+
std::to_string(histogram.get_ndim()) + "d");
161+
}
162+
163+
if (histogram.get_size() != bins.get_size() - 1) {
164+
auto hist_size = histogram.get_size();
165+
auto bins_size = bins.get_size();
166+
throw py::value_error(
167+
get_name(&histogram) + " parameter and " + get_name(&bins) +
168+
" parameters shape mismatch. " + get_name(&histogram) +
169+
" size is " + std::to_string(hist_size) + get_name(&bins) +
170+
" must have size " + std::to_string(hist_size + 1) +
171+
" but have " + std::to_string(bins_size));
167172
}
168173
}
169174
else if (sample.get_ndim() == 2) {
170175
auto sample_count = sample.get_shape(0);
171176
auto expected_dims = sample.get_shape(1);
172177

173-
if (bins.get_ndim() != expected_dims) {
174-
throw py::value_error(get_name(&sample) + " parameter has shape {" +
175-
std::to_string(sample_count) + "x" +
176-
std::to_string(expected_dims) + "}" +
177-
", so " + get_name(&bins) +
178+
if (histogram.get_ndim() != expected_dims) {
179+
throw py::value_error(get_name(&sample) + " parameter has shape (" +
180+
std::to_string(sample_count) + ", " +
181+
std::to_string(expected_dims) + ")" +
182+
", so " + get_name(&histogram) +
178183
" parameter expected to be " +
179184
std::to_string(expected_dims) +
180185
"d. "
181186
"Actual " +
182-
std::to_string(bins.get_ndim()) + "d");
187+
std::to_string(histogram.get_ndim()) + "d");
183188
}
184189
}
185190

186-
py::ssize_t expected_hist_size = 1;
187-
for (int i = 0; i < bins.get_ndim(); ++i) {
188-
expected_hist_size *= (bins.get_shape(i) - 1);
189-
}
190-
191-
if (histogram.get_size() != expected_hist_size) {
192-
throw py::value_error(
193-
get_name(&histogram) + " and " + get_name(&bins) +
194-
" shape mismatch. " + get_name(&histogram) +
195-
" expected to have size = " + std::to_string(expected_hist_size) +
196-
". Actual " + std::to_string(histogram.get_size()));
197-
}
198-
199191
int64_t max_hist_size = std::numeric_limits<uint32_t>::max() - 1;
200192
if (histogram.get_size() > max_hist_size) {
201193
throw py::value_error(get_name(&histogram) +

dpnp/backend/extensions/statistics/histogram_common.hpp

Lines changed: 32 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -52,12 +52,15 @@ template <typename T, int Dims>
5252
struct CachedData
5353
{
5454
static constexpr bool const sync_after_init = true;
55-
using pointer_type = T *;
55+
using Shape = sycl::range<Dims>;
56+
using value_type = T;
57+
using pointer_type = value_type *;
58+
static constexpr auto dims = Dims;
5659

57-
using ncT = typename std::remove_const<T>::type;
60+
using ncT = typename std::remove_const<value_type>::type;
5861
using LocalData = sycl::local_accessor<ncT, Dims>;
5962

60-
CachedData(T *global_data, sycl::range<Dims> shape, sycl::handler &cgh)
63+
CachedData(T *global_data, Shape shape, sycl::handler &cgh)
6164
{
6265
this->global_data = global_data;
6366
local_data = LocalData(shape, cgh);
@@ -87,17 +90,30 @@ struct CachedData
8790
return local_data.size();
8891
}
8992

93+
T &operator[](const sycl::id<Dims> &id) const
94+
{
95+
return local_data[id];
96+
}
97+
98+
template <typename = std::enable_if_t<Dims == 1>>
99+
T &operator[](const size_t id) const
100+
{
101+
return local_data[id];
102+
}
103+
90104
private:
91105
LocalData local_data;
92-
T *global_data = nullptr;
106+
value_type *global_data = nullptr;
93107
};
94108

95109
template <typename T, int Dims>
96110
struct UncachedData
97111
{
98112
static constexpr bool const sync_after_init = false;
99113
using Shape = sycl::range<Dims>;
100-
using pointer_type = T *;
114+
using value_type = T;
115+
using pointer_type = value_type *;
116+
static constexpr auto dims = Dims;
101117

102118
UncachedData(T *global_data, const Shape &shape, sycl::handler &)
103119
{
@@ -120,6 +136,17 @@ struct UncachedData
120136
return _shape.size();
121137
}
122138

139+
T &operator[](const sycl::id<Dims> &id) const
140+
{
141+
return global_data[id];
142+
}
143+
144+
template <typename = std::enable_if_t<Dims == 1>>
145+
T &operator[](const size_t id) const
146+
{
147+
return global_data[id];
148+
}
149+
123150
private:
124151
T *global_data = nullptr;
125152
Shape _shape;

0 commit comments

Comments
 (0)