3838
3939#include " histogram_common.hpp"
4040
41+ #include " validation_utils.hpp"
42+
4143namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
4244using dpctl::tensor::usm_ndarray;
4345using dpctl_td_ns::typenum_t ;
@@ -46,6 +48,15 @@ namespace statistics
4648{
4749using common::CeilDiv;
4850
51+ using validation::array_names;
52+ using validation::array_ptr;
53+
54+ using validation::check_max_dims;
55+ using validation::check_num_dims;
56+ using validation::check_size_at_least;
57+ using validation::common_checks;
58+ using validation::name_of;
59+
4960namespace histogram
5061{
5162
@@ -55,11 +66,9 @@ void validate(const usm_ndarray &sample,
5566 const usm_ndarray &histogram)
5667{
5768 auto exec_q = sample.get_queue ();
58- using array_ptr = const usm_ndarray *;
5969
6070 std::vector<array_ptr> arrays{&sample, &histogram};
61- std::unordered_map<array_ptr, std::string> names = {
62- {arrays[0 ], " sample" }, {arrays[1 ], " histogram" }};
71+ array_names names = {{arrays[0 ], " sample" }, {arrays[1 ], " histogram" }};
6372
6473 array_ptr bins_ptr = nullptr ;
6574
@@ -77,117 +86,48 @@ void validate(const usm_ndarray &sample,
7786 names.insert ({weights_ptr, " weights" });
7887 }
7988
80- auto get_name = [&](const array_ptr &arr) {
81- auto name_it = names.find (arr);
82- assert (name_it != names.end ());
83-
84- return " '" + name_it->second + " '" ;
85- };
86-
87- dpctl::tensor::validation::CheckWritable::throw_if_not_writable (histogram);
88-
89- auto unequal_queue =
90- std::find_if (arrays.cbegin (), arrays.cend (), [&](const array_ptr &arr) {
91- return arr->get_queue () != exec_q;
92- });
93-
94- if (unequal_queue != arrays.cend ()) {
95- throw py::value_error (
96- get_name (*unequal_queue) +
97- " parameter has incompatible queue with parameter " +
98- get_name (&sample));
99- }
100-
101- auto non_contig_array =
102- std::find_if (arrays.cbegin (), arrays.cend (), [&](const array_ptr &arr) {
103- return !arr->is_c_contiguous ();
104- });
89+ common_checks ({&sample, bins.has_value () ? &bins.value () : nullptr ,
90+ weights.has_value () ? &weights.value () : nullptr },
91+ {&histogram}, names);
10592
106- if (non_contig_array != arrays.cend ()) {
107- throw py::value_error (get_name (*non_contig_array) +
108- " parameter is not c-contiguos" );
109- }
93+ check_size_at_least (bins_ptr, 2 , names);
11094
111- auto check_overlaping = [&](const array_ptr &first,
112- const array_ptr &second) {
113- if (first == nullptr || second == nullptr ) {
114- return ;
115- }
116-
117- const auto &overlap = dpctl::tensor::overlap::MemoryOverlap ();
118-
119- if (overlap (*first, *second)) {
120- throw py::value_error (get_name (first) +
121- " has overlapping memory segments with " +
122- get_name (second));
123- }
124- };
125-
126- check_overlaping (&sample, &histogram);
127- check_overlaping (bins_ptr, &histogram);
128- check_overlaping (weights_ptr, &histogram);
129-
130- if (bins_ptr && bins_ptr->get_size () < 2 ) {
131- throw py::value_error (get_name (bins_ptr) +
132- " parameter must have at least 2 elements" );
133- }
134-
135- if (histogram.get_size () < 1 ) {
136- throw py::value_error (get_name (&histogram) +
137- " parameter must have at least 1 element" );
138- }
139-
140- if (histogram.get_ndim () != 1 ) {
141- throw py::value_error (get_name (&histogram) +
142- " parameter must be 1d. Actual " +
143- std::to_string (histogram.get_ndim ()) + " d" );
144- }
95+ check_size_at_least (&histogram, 1 , names);
96+ check_num_dims (&histogram, 1 , names);
14597
14698 if (weights_ptr) {
147- if (weights_ptr->get_ndim () != 1 ) {
148- throw py::value_error (
149- get_name (weights_ptr) + " parameter must be 1d. Actual " +
150- std::to_string (weights_ptr->get_ndim ()) + " d" );
151- }
99+ check_num_dims (weights_ptr, 1 , names);
152100
153101 auto sample_size = sample.get_size ();
154102 auto weights_size = weights_ptr->get_size ();
155103 if (sample.get_size () != weights_ptr->get_size ()) {
156- throw py::value_error (
157- get_name (&sample) + " size (" + std::to_string (sample_size) +
158- " ) and " + get_name (weights_ptr) + " size (" +
159- std::to_string (weights_size) + " )" + " must match" );
104+ throw py::value_error (name_of (&sample, names) + " size (" +
105+ std::to_string (sample_size) + " ) and " +
106+ name_of (weights_ptr, names) + " size (" +
107+ std::to_string (weights_size) + " )" +
108+ " must match" );
160109 }
161110 }
162111
163- if (sample.get_ndim () > 2 ) {
164- throw py::value_error (
165- get_name (&sample) +
166- " parameter must have no more than 2 dimensions. Actual " +
167- std::to_string (sample.get_ndim ()) + " d" );
168- }
112+ check_max_dims (&sample, 2 , names);
169113
170114 if (sample.get_ndim () == 1 ) {
171- if (bins_ptr != nullptr && bins_ptr->get_ndim () != 1 ) {
172- throw py::value_error (get_name (&sample) + " parameter is 1d, but " +
173- get_name (bins_ptr) + " is " +
174- std::to_string (bins_ptr->get_ndim ()) + " d" );
175- }
115+ check_num_dims (bins_ptr, 1 , names);
176116 }
177117 else if (sample.get_ndim () == 2 ) {
178118 auto sample_count = sample.get_shape (0 );
179119 auto expected_dims = sample.get_shape (1 );
180120
181121 if (bins_ptr != nullptr && bins_ptr->get_ndim () != expected_dims) {
182- throw py::value_error (get_name (&sample) + " parameter has shape { " +
183- std::to_string (sample_count) + " x " +
184- std::to_string (expected_dims ) + " } " +
185- " , so " + get_name (bins_ptr) +
186- " parameter expected to be " +
187- std::to_string (expected_dims) +
188- " d. "
189- " Actual " +
190- std::to_string (bins->get_ndim ()) + " d" );
122+ throw py::value_error (
123+ name_of (&sample, names) + " parameter has shape { " +
124+ std::to_string (sample_count ) + " x " +
125+ std::to_string (expected_dims) + " } " + " , so " +
126+ name_of (bins_ptr, names) + " parameter expected to be " +
127+ std::to_string (expected_dims) +
128+ " d. "
129+ " Actual " +
130+ std::to_string (bins->get_ndim ()) + " d" );
191131 }
192132 }
193133
@@ -199,17 +139,17 @@ void validate(const usm_ndarray &sample,
199139
200140 if (histogram.get_size () != expected_hist_size) {
201141 throw py::value_error (
202- get_name (&histogram) + " and " + get_name (bins_ptr) +
203- " shape mismatch. " + get_name (&histogram) +
204- " expected to have size = " +
142+ name_of (&histogram, names ) + " and " +
143+ name_of (bins_ptr, names) + " shape mismatch. " +
144+ name_of (&histogram, names) + " expected to have size = " +
205145 std::to_string (expected_hist_size) + " . Actual " +
206146 std::to_string (histogram.get_size ()));
207147 }
208148 }
209149
210150 int64_t max_hist_size = std::numeric_limits<uint32_t >::max () - 1 ;
211151 if (histogram.get_size () > max_hist_size) {
212- throw py::value_error (get_name (&histogram) +
152+ throw py::value_error (name_of (&histogram, names ) +
213153 " parameter size expected to be less than " +
214154 std::to_string (max_hist_size) + " . Actual " +
215155 std::to_string (histogram.get_size ()));
@@ -225,7 +165,7 @@ void validate(const usm_ndarray &sample,
225165 if (!_64bit_atomics) {
226166 auto device_name = device.get_info <sycl::info::device::name>();
227167 throw py::value_error (
228- get_name (&histogram) +
168+ name_of (&histogram, names ) +
229169 " parameter has 64-bit type, but 64-bit atomics " +
230170 " are not supported for " + device_name);
231171 }
0 commit comments