Skip to content

Commit 406dd47

Browse files
Fix review comments
1 parent 7fc1df1 commit 406dd47

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ struct IsNan
101101
{
102102
static bool isnan(const T &v)
103103
{
104-
if constexpr (std::is_floating_point<T>::value) {
104+
if constexpr (std::is_floating_point_v<T>) {
105105
return sycl::isnan(v);
106106
}
107107

dpnp/backend/extensions/statistics/histogram_common.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
#include <vector>
3131

3232
#include "dpctl4pybind11.hpp"
33+
#include "utils/memory_overlap.hpp"
34+
#include "utils/output_validation.hpp"
3335
#include "utils/type_dispatch.hpp"
36+
3437
#include <pybind11/pybind11.h>
3538

3639
#include "histogram_common.hpp"
@@ -73,6 +76,8 @@ void validate(const usm_ndarray &sample,
7376
return "'" + name_it->second + "'";
7477
};
7578

79+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(histogram);
80+
7681
auto unequal_queue =
7782
std::find_if(arrays.cbegin(), arrays.cend(), [&](const array_ptr &arr) {
7883
return arr->get_queue() != exec_q;
@@ -95,6 +100,21 @@ void validate(const usm_ndarray &sample,
95100
" parameter is not c-contiguos");
96101
}
97102

103+
auto check_overlaping = [&](const array_ptr &first,
104+
const array_ptr &second) {
105+
const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
106+
107+
if (overlap(sample, histogram)) {
108+
throw py::value_error(get_name(first) +
109+
" has overlapping memory segments with " +
110+
get_name(second));
111+
}
112+
};
113+
114+
check_overlaping(&sample, &histogram);
115+
check_overlaping(&bins, &histogram);
116+
check_overlaping(weights_ptr, &histogram);
117+
98118
if (bins.get_size() < 2) {
99119
throw py::value_error(get_name(&bins) +
100120
" parameter must have at least 2 elements");

dpnp/dpnp_iface_histograms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -486,25 +486,25 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
486486
)
487487
n_usm = dpnp.get_usm_ndarray(n_casted)
488488

489-
ht_ev, mem_ev = statistics_ext.histogram(
489+
mem_ev, ht_ev = statistics_ext.histogram(
490490
a_usm,
491491
bins_usm,
492492
weights_usm,
493493
n_usm,
494494
depends=_manager.submitted_events,
495495
)
496-
_manager.add_event_pair(ht_ev, mem_ev)
496+
_manager.add_event_pair(mem_ev, ht_ev)
497497

498498
if usm_type != n_usm_type:
499499
n = dpnp.asarray(n_casted, dtype=ntype, usm_type=usm_type)
500500
else:
501501
n = dpnp.astype(n_casted, ntype, copy=False)
502502

503503
if density:
504-
db = dpnp.diff(bin_edges).astype(
505-
dpnp.default_float_type(sycl_queue=queue)
504+
db = dpnp.astype(
505+
dpnp.diff(bin_edges), dpnp.default_float_type(sycl_queue=queue)
506506
)
507-
return n / db / n.sum(), bin_edges
507+
return n / db / dpnp.sum(n), bin_edges
508508

509509
return n, bin_edges
510510

0 commit comments

Comments
 (0)