Skip to content

Commit cd9008d

Browse files
Applying review comments
1 parent c7deaff commit cd9008d

File tree

10 files changed

+85
-65
lines changed

10 files changed

+85
-65
lines changed

dpnp/backend/extensions/statistics/bincount.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,8 @@ struct BincountEdges
7272
template <typename dT>
7373
bool in_bounds(const dT *val, const boundsT &bounds) const
7474
{
75-
Less<dT> _less;
76-
return !_less(val[0], std::get<0>(bounds)) &&
77-
!_less(std::get<1>(bounds), val[0]) && !IsNan<dT>::isnan(val[0]);
75+
return check_in_bounds(val[0], std::get<0>(bounds),
76+
std::get<1>(bounds));
7877
}
7978

8079
private:
@@ -87,8 +86,8 @@ struct BincountF
8786
{
8887
static sycl::event impl(sycl::queue &exec_q,
8988
const void *vin,
90-
int64_t min,
91-
int64_t max,
89+
const int64_t min,
90+
const int64_t max,
9291
const void *vweights,
9392
void *vout,
9493
const size_t,
@@ -100,22 +99,20 @@ struct BincountF
10099
// shift output pointer by min elements
101100
HistType *out = static_cast<HistType *>(vout) + min;
102101

103-
size_t needed_bins_count = (max - min) + 1;
104-
auto device = exec_q.get_device();
102+
const size_t needed_bins_count = (max - min) + 1;
105103

106-
uint32_t local_size = get_max_local_size(exec_q);
104+
const uint32_t local_size = get_max_local_size(exec_q);
107105

108-
uint32_t WorkPI = 128; // empirically found number
109-
auto nd_range = make_ndrange(size, local_size, WorkPI);
106+
constexpr uint32_t WorkPI = 128; // empirically found number
107+
const auto nd_range = make_ndrange(size, local_size, WorkPI);
110108

111109
return exec_q.submit([&](sycl::handler &cgh) {
112110
cgh.depends_on(depends);
113111
constexpr uint32_t dims = 1;
114112

115113
auto dispatch_bins = [&](const auto &weights) {
116-
auto local_mem_size =
117-
device.get_info<sycl::info::device::local_mem_size>() /
118-
sizeof(T);
114+
const auto local_mem_size =
115+
get_local_mem_size_in_items<T>(exec_q);
119116
if (local_mem_size >= needed_bins_count) {
120117
uint32_t local_hist_count = get_local_hist_copies_count(
121118
local_mem_size, local_size, needed_bins_count);

dpnp/backend/extensions/statistics/common.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,19 @@ size_t get_local_mem_size_in_items(const sycl::device &device, size_t reserve)
165165
return get_local_mem_size_in_bytes(device, sizeof(T) * reserve) / sizeof(T);
166166
}
167167

168+
template <typename T>
169+
inline size_t get_local_mem_size_in_items(const sycl::queue &queue)
170+
{
171+
return get_local_mem_size_in_items<T>(queue.get_device());
172+
}
173+
174+
template <typename T>
175+
inline size_t get_local_mem_size_in_items(const sycl::queue &queue,
176+
size_t reserve)
177+
{
178+
return get_local_mem_size_in_items<T>(queue.get_device(), reserve);
179+
}
180+
168181
template <int Dims>
169182
sycl::nd_range<Dims> make_ndrange(const sycl::range<Dims> &global_range,
170183
const sycl::range<Dims> &local_range,

dpnp/backend/extensions/statistics/histogram.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,8 @@ struct HistogramEdges
9494
template <typename dT>
9595
bool in_bounds(const dT *val, const boundsT &bounds) const
9696
{
97-
Less<dT> _less;
98-
return !_less(val[0], std::get<0>(bounds)) &&
99-
!_less(std::get<1>(bounds), val[0]) && !IsNan<dT>::isnan(val[0]);
97+
return check_in_bounds(val[0], std::get<0>(bounds),
98+
std::get<1>(bounds));
10099
}
101100

102101
private:
@@ -110,7 +109,7 @@ template <typename T>
110109
using UncachedEdges = HistogramEdges<T, UncachedData<const T, 1>>;
111110

112111
template <typename T, typename BinsT, typename HistType = size_t>
113-
struct histogram_kernel
112+
struct HistogramF
114113
{
115114
static sycl::event impl(sycl::queue &exec_q,
116115
const void *vin,
@@ -185,7 +184,7 @@ struct histogram_kernel
185184
};
186185

187186
template <typename SampleType, typename HistType>
188-
using histogram_kernel_ = histogram_kernel<SampleType, SampleType, HistType>;
187+
using HistogramF_ = HistogramF<SampleType, SampleType, HistType>;
189188

190189
} // namespace
191190

@@ -212,7 +211,7 @@ using SupportedTypes = std::tuple<std::tuple<uint64_t, int64_t>,
212211

213212
Histogram::Histogram() : dispatch_table("sample", "histogram")
214213
{
215-
dispatch_table.populate_dispatch_table<SupportedTypes, histogram_kernel_>();
214+
dispatch_table.populate_dispatch_table<SupportedTypes, HistogramF_>();
216215
}
217216

218217
std::tuple<sycl::event, sycl::event>

dpnp/backend/extensions/statistics/histogram_common.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,13 @@ struct Weights
278278
T *data = nullptr;
279279
};
280280

281+
template <typename dT>
282+
bool check_in_bounds(const dT &val, const dT &min, const dT &max)
283+
{
284+
Less<dT> _less;
285+
return !_less(val, min) && !_less(max, val) && !IsNan<dT>::isnan(val);
286+
}
287+
281288
template <typename T, typename HistImpl, typename Edges, typename Weights>
282289
class histogram_kernel;
283290

dpnp/dpnp_iface_histograms.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
"""
3939

4040
import operator
41-
import warnings
4241

4342
import dpctl.utils as dpu
4443
import numpy
@@ -104,16 +103,6 @@ def _ravel_check_a_and_weights(a, weights):
104103
dpnp.check_supported_arrays_type(a)
105104
usm_type = a.usm_type
106105

107-
# ensure that the array is a "subtractable" dtype
108-
if a.dtype == dpnp.bool:
109-
warnings.warn(
110-
f"Converting input from {a.dtype} to {numpy.uint8} "
111-
"for compatibility.",
112-
RuntimeWarning,
113-
stacklevel=3,
114-
)
115-
a = dpnp.astype(a, numpy.uint8)
116-
117106
if weights is not None:
118107
# check that `weights` array has supported type
119108
dpnp.check_supported_arrays_type(weights)
@@ -323,22 +312,26 @@ def bincount(x, weights=None, minlength=None):
323312
324313
Parameters
325314
----------
326-
x : {dpnp.ndarray, usm_ndarray}, 1 dimension, nonnegative ints
327-
Input array.
328-
weights : {dpnp.ndarray, usm_ndarray}, optional
315+
x : {dpnp.ndarray, usm_ndarray}
316+
Input 1-dimensional array with nonnegative integer values.
317+
weights : {None, dpnp.ndarray, usm_ndarray}, optional
329318
Weights, array of the same shape as `x`.
330-
minlength : int, optional
319+
Default: ``None``
320+
minlength : {None, int}, optional
331321
A minimum number of bins for the output array.
322+
Default: ``None``
332323
333324
Returns
334325
-------
335326
out : dpnp.ndarray of ints
336327
The result of binning the input array.
337-
The length of `out` is equal to ``np.amax(x)+1``.
328+
The length of `out` is equal to ``np.amax(x) + 1``.
338329
339330
See Also
340331
--------
341-
dpnp.histogram, dpnp.digitize, dpnp.unique
332+
:obj:`dpnp.histogram` : Compute the histogram of a data set.
333+
:obj:`dpnp.digitize` : Return the indices of the bins to which each value
334+
:obj:`dpnp.unique` : Find the unique elements of an array.
342335
343336
Examples
344337
--------
@@ -349,25 +342,24 @@ def bincount(x, weights=None, minlength=None):
349342
array([1, 3, 1, 1, 0, 0, 0, 1])
350343
351344
>>> x = np.array([0, 1, 1, 3, 2, 1, 7, 23])
352-
>>> np.bincount(x).size == np.amax(x)+1
353-
True
345+
>>> np.bincount(x).size == np.amax(x) + 1
346+
array(True)
354347
355348
The input array needs to be of integer dtype, otherwise a
356349
TypeError is raised:
357350
358-
>>> np.bincount(np.arange(5, dtype=float))
351+
>>> np.bincount(np.arange(5, dtype=np.float32))
359352
Traceback (most recent call last):
360353
...
361-
TypeError: Cannot cast array data from dtype('float64') to dtype('int64')
362-
according to the rule 'safe'
354+
TypeError: x must be an integer array
363355
364356
A possible use of ``bincount`` is to perform sums over
365-
variable-size chunks of an array, using the ``weights`` keyword.
357+
variable-size chunks of an array, using the `weights` keyword.
366358
367-
>>> w = np.array([0.3, 0.5, 0.2, 0.7, 1., -0.6]) # weights
359+
>>> w = np.array([0.3, 0.5, 0.2, 0.7, 1., -0.6], dtype=np.float32) # weights
368360
>>> x = np.array([0, 1, 1, 2, 2, 2])
369-
>>> np.bincount(x, weights=w)
370-
array([ 0.3, 0.7, 1.1])
361+
>>> np.bincount(x, weights=w)
362+
array([0.3, 0.7, 1.1], dtype=float32)
371363
372364
"""
373365

tests/helper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from collections.abc import Iterable
21
from sys import platform
32

43
import dpctl

tests/test_histogram.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -258,22 +258,6 @@ def test_outliers_normalization_weights(self):
258258
assert_allclose(result_hist, expected_hist)
259259
assert_allclose(result_edges, expected_edges)
260260

261-
@pytest.mark.parametrize("xp", [numpy, dpnp])
262-
def test_bool_conversion(self, xp):
263-
a = xp.array([1, 1, 0], dtype=numpy.uint8)
264-
int_hist, int_edges = xp.histogram(a)
265-
266-
with suppress_warnings() as sup:
267-
rec = sup.record(RuntimeWarning, "Converting input from .*")
268-
269-
v = xp.array([True, True, False])
270-
hist, edges = xp.histogram(v)
271-
272-
# A warning should be issued
273-
assert len(rec) == 1
274-
assert_array_equal(hist, int_hist)
275-
assert_array_equal(edges, int_edges)
276-
277261
@pytest.mark.parametrize("density", [True, False])
278262
def test_weights(self, density):
279263
v = numpy.random.rand(100)
@@ -574,6 +558,13 @@ def test_weights_another_sycl_queue(self):
574558
with assert_raises(ValueError):
575559
dpnp.bincount(v, weights=w)
576560

561+
@pytest.mark.parametrize("xp", [numpy, dpnp])
562+
def test_weights_unsupported_dtype(self, xp):
563+
v = dpnp.arange(5)
564+
w = dpnp.arange(5, dtype=dpnp.complex64)
565+
with assert_raises(ValueError):
566+
dpnp.bincount(v, weights=w)
567+
577568
@pytest.mark.parametrize(
578569
"bins_count",
579570
[10, 10**2, 10**3, 10**4, 10**5, 10**6],

tests/test_sycl_queue.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2460,6 +2460,27 @@ def test_lstsq(m, n, nrhs, device):
24602460
assert_sycl_queue_equal(param_dp.sycl_queue, b_dp.sycl_queue)
24612461

24622462

2463+
@pytest.mark.parametrize("weights", [None, numpy.arange(7, 12)])
2464+
@pytest.mark.parametrize(
2465+
"device",
2466+
valid_devices,
2467+
ids=[device.filter_string for device in valid_devices],
2468+
)
2469+
def test_bincount(weights, device):
2470+
v = numpy.arange(5)
2471+
w = weights
2472+
2473+
iv = dpnp.array(v, device=device)
2474+
iw = None if weights is None else dpnp.array(w, sycl_queue=iv.sycl_queue)
2475+
2476+
expected_hist = numpy.bincount(v, weights=w)
2477+
result_hist = dpnp.bincount(iv, weights=iw)
2478+
assert_array_equal(result_hist, expected_hist)
2479+
2480+
hist_queue = result_hist.sycl_queue
2481+
assert_sycl_queue_equal(hist_queue, iv.sycl_queue)
2482+
2483+
24632484
@pytest.mark.parametrize("weights", [None, numpy.arange(7, 12)])
24642485
@pytest.mark.parametrize(
24652486
"device",

tests/test_usm_type.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,11 +1475,10 @@ def test_bincount(usm_type_v, usm_type_w):
14751475
v = dp.arange(5, usm_type=usm_type_v)
14761476
w = dp.arange(7, 12, usm_type=usm_type_w)
14771477

1478-
hist, edges = dp.histogram(v, weights=w)
1478+
hist = dp.bincount(v, weights=w)
14791479
assert v.usm_type == usm_type_v
14801480
assert w.usm_type == usm_type_w
14811481
assert hist.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w])
1482-
assert edges.usm_type == du.get_coerced_usm_type([usm_type_v, usm_type_w])
14831482

14841483

14851484
@pytest.mark.parametrize(

tests/third_party/cupy/statistics_tests/test_histogram.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ def test_bincount_duplicated_value(self, xp, dtype):
271271
return xp.bincount(x)
272272

273273
@for_all_dtypes_combination_bincount(names=["x_type", "w_type"])
274-
@testing.numpy_cupy_allclose(accept_error=TypeError, type_check=False)
274+
@testing.numpy_cupy_allclose(
275+
accept_error=TypeError, type_check=has_support_aspect64()
276+
)
275277
def test_bincount_with_weight(self, xp, x_type, w_type):
276278
x = testing.shaped_arange((3,), xp, x_type)
277279
w = testing.shaped_arange((3,), xp, w_type)

0 commit comments

Comments
 (0)