Skip to content

Commit e2b6217

Browse files
Fix review comments
1 parent 46ddbe8 commit e2b6217

File tree

4 files changed: 18 additions and 28 deletions

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ struct IsNan
101101
{
102102
static bool isnan(const T &v)
103103
{
104-
if constexpr (std::is_floating_point_v<T>) {
104+
if constexpr (std::is_floating_point_v<T> ||
105+
std::is_same_v<T, sycl::half>) {
105106
return sycl::isnan(v);
106107
}
107108

dpnp/backend/extensions/statistics/histogram.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@
4141
#include "histogram.hpp"
4242
#include "histogram_common.hpp"
4343

44-
#include <iostream>
45-
4644
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4745
using dpctl::tensor::usm_ndarray;
4846

@@ -139,7 +137,7 @@ struct histogram_kernel
139137
cgh.depends_on(depends);
140138
constexpr uint32_t dims = 1;
141139

142-
auto dispatch_edges = [&](uint32_t local_mem, auto &weights,
140+
auto dispatch_edges = [&](uint32_t local_mem, const auto &weights,
143141
auto &hist) {
144142
if (device.is_gpu() && (local_mem >= bins_count + 1)) {
145143
auto edges = CachedEdges(bins_edges, bins_count + 1, cgh);
@@ -153,7 +151,7 @@ struct histogram_kernel
153151
}
154152
};
155153

156-
auto dispatch_bins = [&](auto &weights) {
154+
auto dispatch_bins = [&](const auto &weights) {
157155
const auto local_mem_size =
158156
get_local_mem_size_in_items<T>(device);
159157
if (local_mem_size >= bins_count) {

dpnp/backend/extensions/statistics/histogram_common.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,13 @@ void validate(const usm_ndarray &sample,
102102

103103
auto check_overlaping = [&](const array_ptr &first,
104104
const array_ptr &second) {
105+
if (first == nullptr || second == nullptr) {
106+
return;
107+
}
108+
105109
const auto &overlap = dpctl::tensor::overlap::MemoryOverlap();
106110

107-
if (overlap(sample, histogram)) {
111+
if (overlap(*first, *second)) {
108112
throw py::value_error(get_name(first) +
109113
" has overlapping memory segments with " +
110114
get_name(second));

dpnp/dpnp_iface_histograms.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import dpctl.utils as dpu
4444
import numpy
45+
from dpctl.tensor._type_utils import _can_cast
4546

4647
import dpnp
4748

@@ -304,26 +305,10 @@ def _result_type_for_device(dtype1, dtype2, device):
304305
return map_dtype_to_device(rt, device)
305306

306307

307-
def _supported_by_device(dtype, device):
308-
if dtype in [dpnp.float64, dpnp.complex128]:
309-
return device.has_aspect_fp64
310-
311-
if dtype == dpnp.float16:
312-
return device.has_aspect_fp16
313-
314-
return True
315-
316-
317-
def _can_cast(dtype1, dtype2, device):
318-
if not _supported_by_device(dtype1, device) or not _supported_by_device(
319-
dtype2, device
320-
):
321-
return False
322-
323-
return dpnp.can_cast(dtype1, dtype2)
324-
325-
326308
def _align_dtypes(a_dtype, bins_dtype, ntype, supported_types, device):
309+
has_fp64 = device.has_aspect_fp64
310+
has_fp16 = device.has_aspect_fp16
311+
327312
a_bin_dtype = _result_type_for_device(a_dtype, bins_dtype, device)
328313

329314
# histogram implementation doesn't support uint64 as histogram type
@@ -335,11 +320,12 @@ def _align_dtypes(a_dtype, bins_dtype, ntype, supported_types, device):
335320
return a_bin_dtype, ntype
336321

337322
for sample_type, hist_type in supported_types:
338-
if _can_cast(a_bin_dtype, sample_type, device) and _can_cast(
339-
ntype, hist_type, device
340-
):
323+
if _can_cast(
324+
a_bin_dtype, sample_type, has_fp16, has_fp64
325+
) and _can_cast(ntype, hist_type, has_fp16, has_fp64):
341326
return sample_type, hist_type
342327

328+
# should not happen: no supported (sample, histogram) type pair accepted the cast
343329
return None, None
344330

345331

@@ -468,6 +454,7 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
468454
# host usm memory
469455
n_usm_type = "device" if usm_type == "host" else usm_type
470456

457+
# histogram implementation requires output array to be filled with zeros
471458
n_casted = dpnp.zeros(
472459
bin_edges.size - 1,
473460
dtype=hist_dtype,

0 commit comments

Comments
 (0)