Skip to content

Commit ed7cbcc

Browse files
Fix review comments
1 parent 7fc1df1 commit ed7cbcc

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ struct IsNan
101101
{
102102
static bool isnan(const T &v)
103103
{
104-
if constexpr (std::is_floating_point<T>::value) {
104+
if constexpr (std::is_floating_point_v<T>) {
105105
return sycl::isnan(v);
106106
}
107107

dpnp/backend/extensions/statistics/histogram_common.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
#include <vector>
3131

3232
#include "dpctl4pybind11.hpp"
33+
#include "utils/memory_overlap.hpp"
34+
#include "utils/output_validation.hpp"
3335
#include "utils/type_dispatch.hpp"
36+
3437
#include <pybind11/pybind11.h>
3538

3639
#include "histogram_common.hpp"
@@ -73,6 +76,8 @@ void validate(const usm_ndarray &sample,
7376
return "'" + name_it->second + "'";
7477
};
7578

79+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(histogram);
80+
7681
auto unequal_queue =
7782
std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
7883
return arr->get_queue() != exec_q;
@@ -95,6 +100,21 @@ void validate(const usm_ndarray &sample,
95100
" parameter is not c-contiguos");
96101
}
97102

103+
auto check_overlaping = [&](const array_ptr &first,
104+
const array_ptr &second) {
105+
const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
106+
107+
if (overlap(sample, histogram)) {
108+
throw py::value_error(get_name(first) +
109+
" has overlapping memory segments with " +
110+
get_name(second));
111+
}
112+
};
113+
114+
check_overlaping(&sample, &histogram);
115+
check_overlaping(&bins, &histogram);
116+
check_overlaping(weights_ptr, &histogram);
117+
98118
if (bins.get_size() < 2) {
99119
throw py::value_error(get_name(&bins) +
100120
" parameter must have at least 2 elements");

dpnp/dpnp_iface_histograms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -501,10 +501,10 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
501501
n = dpnp.astype(n_casted, ntype, copy=False)
502502

503503
if density:
504-
db = dpnp.diff(bin_edges).astype(
505-
dpnp.default_float_type(sycl_queue=queue)
504+
db = dpnp.astype(
505+
dpnp.diff(bin_edges), dpnp.default_float_type(sycl_queue=queue)
506506
)
507-
return n / db / n.sum(), bin_edges
507+
return n / db / dpnp.sum(n), bin_edges
508508

509509
return n, bin_edges
510510

0 commit comments

Comments
 (0)