2424// *****************************************************************************
2525
2626#include < algorithm>
27+ #include < limits>
2728#include < string>
2829#include < unordered_map>
2930#include < vector>
3031
3132#include " dpctl4pybind11.hpp"
33+ #include " utils/type_dispatch.hpp"
3234#include < pybind11/pybind11.h>
3335
3436#include " histogram_common.hpp"
3537
38+ namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3639using dpctl::tensor::usm_ndarray;
40+ using dpctl_td_ns::typenum_t ;
3741
3842namespace histogram
3943{
@@ -45,9 +49,8 @@ void validate(const usm_ndarray &sample,
4549{
4650 auto exec_q = sample.get_queue ();
4751 using array_ptr = const usm_ndarray *;
48- using array_list = std::vector<array_ptr>;
4952
50- array_list arrays{&sample, &bins, &histogram};
53+ std::vector<array_ptr> arrays{&sample, &bins, &histogram};
5154 std::unordered_map<array_ptr, std::string> names = {
5255 {arrays[0 ], " sample" }, {arrays[1 ], " bins" }, {arrays[2 ], " histogram" }};
5356
@@ -94,10 +97,21 @@ void validate(const usm_ndarray &sample,
9497 std::to_string (histogram.get_ndim ()) + " d" );
9598 }
9699
97- if (weights_ptr && weights_ptr->get_ndim () != 1 ) {
98- throw py::value_error (get_name (weights_ptr) +
99- " parameter must be 1d. Actual " +
100- std::to_string (weights_ptr->get_ndim ()) + " d" );
100+ if (weights_ptr) {
101+ if (weights_ptr->get_ndim () != 1 ) {
102+ throw py::value_error (
103+ get_name (weights_ptr) + " parameter must be 1d. Actual " +
104+ std::to_string (weights_ptr->get_ndim ()) + " d" );
105+ }
106+
107+ auto sample_size = sample.get_size ();
108+ auto weights_size = weights_ptr->get_size ();
109+ if (sample.get_size () != weights_ptr->get_size ()) {
110+ throw py::value_error (
111+ get_name (&sample) + " size (" + std::to_string (sample_size) +
112+ " ) and " + get_name (weights_ptr) + " size (" +
113+ std::to_string (weights_size) + " )" + " must match" );
114+ }
101115 }
102116
103117 if (sample.get_ndim () > 2 ) {
@@ -143,6 +157,32 @@ void validate(const usm_ndarray &sample,
143157 " expected to have size = " + std::to_string (expected_hist_size) +
144158 " . Actual " + std::to_string (histogram.get_size ()));
145159 }
160+
161+ int64_t max_hist_size = std::numeric_limits<uint32_t >::max () - 1 ;
162+ if (histogram.get_size () > max_hist_size) {
163+ throw py::value_error (get_name (&histogram) +
164+ " parameter size expected to be less than " +
165+ std::to_string (max_hist_size) + " . Actual " +
166+ std::to_string (histogram.get_size ())
167+ );
168+ }
169+
170+ auto array_types = dpctl_td_ns::usm_ndarray_types ();
171+ auto hist_type = static_cast <typenum_t >(
172+ array_types.typenum_to_lookup_id (histogram.get_typenum ()));
173+ if (histogram.get_elemsize () == 8 && hist_type != typenum_t ::CFLOAT) {
174+ auto device = exec_q.get_device ();
175+ bool _64bit_atomics = device.has (sycl::aspect::atomic64);
176+
177+ if (!_64bit_atomics) {
178+ auto device_name = device.get_info <sycl::info::device::name>();
179+ throw py::value_error (
180+ get_name (&histogram) +
181+ " parameter has 64-bit type, but 64-bit atomics " +
182+ " are not supported for " + device_name
183+ );
184+ }
185+ }
146186}
147187
148188} // namespace histogram
0 commit comments