@@ -88,16 +88,14 @@ void validate(const usm_ndarray &sample,
8888 {&histogram}, names);
8989
9090 check_size_at_least (bins_ptr, 2 , names);
91-
9291 check_size_at_least (&histogram, 1 , names);
93- check_num_dims (&histogram, 1 , names);
9492
9593 if (weights_ptr) {
9694 check_num_dims (weights_ptr, 1 , names);
9795
98- auto sample_size = sample.get_size ( );
96+ auto sample_size = sample.get_shape ( 0 );
9997 auto weights_size = weights_ptr->get_size ();
100- if (sample. get_size () != weights_ptr->get_size ()) {
98+ if (sample_size != weights_ptr->get_size ()) {
10199 throw py::value_error (name_of (&sample, names) + " size (" +
102100 std::to_string (sample_size) + " ) and " +
103101 name_of (weights_ptr, names) + " size (" +
@@ -110,61 +108,74 @@ void validate(const usm_ndarray &sample,
110108
111109 if (sample.get_ndim () == 1 ) {
112110 check_num_dims (bins_ptr, 1 , names);
111+
112+ if (bins_ptr && histogram.get_size () != bins_ptr->get_size () - 1 ) {
113+ auto hist_size = histogram.get_size ();
114+ auto bins_size = bins_ptr->get_size ();
115+ throw py::value_error (
116+ name_of (&histogram, names) + " parameter and " +
117+ name_of (bins_ptr, names) + " parameters shape mismatch. " +
118+ name_of (&histogram, names) + " size is " +
119+ std::to_string (hist_size) + name_of (bins_ptr, names) +
120+ " must have size " + std::to_string (hist_size + 1 ) +
121+ " but have " + std::to_string (bins_size));
122+ }
113123 }
114124 else if (sample.get_ndim () == 2 ) {
115125 auto sample_count = sample.get_shape (0 );
116126 auto expected_dims = sample.get_shape (1 );
117127
118- if (bins_ptr != nullptr && bins_ptr-> get_ndim () != expected_dims) {
128+ if (histogram. get_ndim () != expected_dims) {
119129 throw py::value_error (
120- name_of (&sample, names) + " parameter has shape { " +
121- std::to_string (sample_count) + " x " +
122- std::to_string (expected_dims) + " } " + " , so " +
123- name_of (bins_ptr , names) + " parameter expected to be " +
130+ name_of (&sample, names) + " parameter has shape ( " +
131+ std::to_string (sample_count) + " , " +
132+ std::to_string (expected_dims) + " ) " + " , so " +
133+ name_of (&histogram , names) + " parameter expected to be " +
124134 std::to_string (expected_dims) +
125135 " d. "
126136 " Actual " +
127- std::to_string (bins-> get_ndim ()) + " d" );
137+ std::to_string (histogram. get_ndim ()) + " d" );
128138 }
129- }
130139
131- if (bins_ptr != nullptr ) {
132- py::ssize_t expected_hist_size = 1 ;
133- for (int i = 0 ; i < bins_ptr->get_ndim (); ++i) {
134- expected_hist_size *= (bins_ptr->get_shape (i) - 1 );
140+ if (bins_ptr != nullptr ) {
141+ py::ssize_t expected_bins_size = 0 ;
142+ for (int i = 0 ; i < histogram.get_ndim (); ++i) {
143+ expected_bins_size += histogram.get_shape (i) + 1 ;
144+ }
145+
146+ auto actual_bins_size = bins_ptr->get_size ();
147+ if (actual_bins_size != expected_bins_size) {
148+ throw py::value_error (
149+ name_of (&histogram, names) + " and " +
150+ name_of (bins_ptr, names) + " shape mismatch. " +
151+ name_of (bins_ptr, names) + " expected to have size = " +
152+ std::to_string (expected_bins_size) + " . Actual " +
153+ std::to_string (actual_bins_size));
154+ }
135155 }
136156
137- if (histogram.get_size () != expected_hist_size) {
138- throw py::value_error (
139- name_of (&histogram, names) + " and " +
140- name_of (bins_ptr, names) + " shape mismatch. " +
141- name_of (&histogram, names) + " expected to have size = " +
142- std::to_string (expected_hist_size) + " . Actual " +
143- std::to_string (histogram.get_size ()));
157+ int64_t max_hist_size = std::numeric_limits<uint32_t >::max () - 1 ;
158+ if (histogram.get_size () > max_hist_size) {
159+ throw py::value_error (name_of (&histogram, names) +
160+ " parameter size expected to be less than " +
161+ std::to_string (max_hist_size) + " . Actual " +
162+ std::to_string (histogram.get_size ()));
144163 }
145- }
146-
147- int64_t max_hist_size = std::numeric_limits<uint32_t >::max () - 1 ;
148- if (histogram.get_size () > max_hist_size) {
149- throw py::value_error (name_of (&histogram, names) +
150- " parameter size expected to be less than " +
151- std::to_string (max_hist_size) + " . Actual " +
152- std::to_string (histogram.get_size ()));
153- }
154164
155- auto array_types = dpctl_td_ns::usm_ndarray_types ();
156- auto hist_type = static_cast <typenum_t >(
157- array_types.typenum_to_lookup_id (histogram.get_typenum ()));
158- if (histogram.get_elemsize () == 8 && hist_type != typenum_t ::CFLOAT) {
159- auto device = exec_q.get_device ();
160- bool _64bit_atomics = device.has (sycl::aspect::atomic64);
161-
162- if (!_64bit_atomics) {
163- auto device_name = device.get_info <sycl::info::device::name>();
164- throw py::value_error (
165- name_of (&histogram, names) +
166- " parameter has 64-bit type, but 64-bit atomics " +
167- " are not supported for " + device_name);
165+ auto array_types = dpctl_td_ns::usm_ndarray_types ();
166+ auto hist_type = static_cast <typenum_t >(
167+ array_types.typenum_to_lookup_id (histogram.get_typenum ()));
168+ if (histogram.get_elemsize () == 8 && hist_type != typenum_t ::CFLOAT) {
169+ auto device = exec_q.get_device ();
170+ bool _64bit_atomics = device.has (sycl::aspect::atomic64);
171+
172+ if (!_64bit_atomics) {
173+ auto device_name = device.get_info <sycl::info::device::name>();
174+ throw py::value_error (
175+ name_of (&histogram, names) +
176+ " parameter has 64-bit type, but 64-bit atomics " +
177+ " are not supported for " + device_name);
178+ }
168179 }
169180 }
170181}
0 commit comments