Skip to content

Commit 520e3d1

Browse files
committed
Move first iteration of indices validation into loop
1 parent 4138d8a commit 520e3d1

File tree

1 file changed

+16
-28
lines changed

1 file changed

+16
-28
lines changed

dpnp/backend/extensions/indexing/choose.cpp

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -239,39 +239,27 @@ std::pair<sycl::event, sycl::event>
239239
auto sh_nelems = std::max<int>(nd, 1);
240240
std::vector<py::ssize_t> chc_strides(n_chcs * sh_nelems, 0);
241241

242-
// first iteration for first choice array chc_rep
243-
if (overlap(dst, chc_rep)) {
244-
throw py::value_error("Arrays index overlapping segments of memory");
245-
}
246-
247-
// chc_strides is initialized to 0 for 0D choices, so skip
248-
if (nd > 0) {
249-
auto chc_strides_ = chc_rep.get_strides_vector();
250-
std::copy(chc_strides_.begin(), chc_strides_.end(),
251-
chc_strides.begin());
252-
}
253-
254-
chc_ptrs.push_back(chc_rep.get_data());
255-
chc_offsets.push_back(py::ssize_t(0));
256-
257-
for (auto i = 1; i < n_chcs; ++i) {
242+
for (auto i = 0; i < n_chcs; ++i) {
258243
dpctl::tensor::usm_ndarray chc_ = chcs[i];
259244

260245
// ndim, type, and shape are checked against the first array
261-
if (!(chc_.get_ndim() == nd)) {
262-
throw py::value_error("Choice array dimensions are not the same");
263-
}
246+
if (i > 0) {
247+
if (!(chc_.get_ndim() == nd)) {
248+
throw py::value_error(
249+
"Choice array dimensions are not the same");
250+
}
264251

265-
if (!(chc_type_id ==
266-
array_types.typenum_to_lookup_id(chc_.get_typenum()))) {
267-
throw py::type_error(
268-
"Choice array data types are not all the same.");
269-
}
252+
if (!(chc_type_id ==
253+
array_types.typenum_to_lookup_id(chc_.get_typenum()))) {
254+
throw py::type_error(
255+
"Choice array data types are not all the same.");
256+
}
270257

271-
const py::ssize_t *chc_shape_ = chc_.get_shape_raw();
272-
for (int dim = 0; dim < nd; ++dim) {
273-
if (!(chc_shape[dim] == chc_shape_[dim])) {
274-
throw py::value_error("Choice shapes are not all equal.");
258+
const py::ssize_t *chc_shape_ = chc_.get_shape_raw();
259+
for (int dim = 0; dim < nd; ++dim) {
260+
if (!(chc_shape[dim] == chc_shape_[dim])) {
261+
throw py::value_error("Choice shapes are not all equal.");
262+
}
275263
}
276264
}
277265

0 commit comments

Comments
 (0)