Skip to content

Commit e296d87

Browse files
committed
Moved advanced_indexing pointer range validation
1 parent 877c3c7 commit e296d87

File tree

1 file changed

+22
-24
lines changed

1 file changed

+22
-24
lines changed

dpctl/tensor/libtensor/source/advanced_indexing.cpp

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
9999
std::shared_ptr<shT> packed_host_axes_shapes_strides_shp =
100100
std::make_shared<shT>(2 * k + along_sh_elems, allocator);
101101

102-
// can be made more efficient by checking if inp_nd > 1, then performing
103-
// same treatment of orthog_sh_elems as for 0D (orthog will not exist)
104102
if (inp_nd > 0) {
105103
std::copy(inp_shape, inp_shape + axis_start,
106104
packed_host_shapes_strides_shp->begin());
@@ -403,6 +401,17 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
403401
}
404402
}
405403

404+
// destination must be ample enough to accommodate all elements
405+
{
406+
size_t range =
407+
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
408+
if ((range + 1) < (orthog_nelems * ind_nelems)) {
409+
throw py::value_error(
410+
"Destination array can not accommodate all the "
411+
"elements of source array.");
412+
}
413+
}
414+
406415
auto ind_sh_elems = (ind_nd > 0) ? ind_nd : 1;
407416

408417
std::vector<char *> ind_ptrs;
@@ -580,17 +589,6 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
580589
const py::ssize_t *src_strides = src.get_strides_raw();
581590
const py::ssize_t *dst_strides = dst.get_strides_raw();
582591

583-
// destination must be ample enough to accommodate all elements
584-
{
585-
size_t range =
586-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
587-
if ((range + 1) < (orthog_nelems * ind_nelems)) {
588-
throw py::value_error(
589-
"Destination array can not accommodate all the "
590-
"elements of source array.");
591-
}
592-
}
593-
594592
// packed_shapes_strides = [src_shape[:axis] + src_shape[axis+k:],
595593
// src_strides[:axis] + src_strides[axis+k:],
596594
// dst_strides[:axis] + dst_strides[axis+k:]]
@@ -765,6 +763,17 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
765763
throw py::value_error("Arrays index overlapping segments of memory");
766764
}
767765

766+
// destination must be ample enough to accommodate all possible elements
767+
{
768+
size_t range =
769+
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
770+
if ((range + 1) < dst_nelems) {
771+
throw py::value_error(
772+
"Destination array can not accommodate all the "
773+
"elements of source array.");
774+
}
775+
}
776+
768777
int dst_typenum = dst.get_typenum();
769778
int val_typenum = val.get_typenum();
770779

@@ -965,17 +974,6 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
965974
const py::ssize_t *dst_strides = dst.get_strides_raw();
966975
const py::ssize_t *val_strides = val.get_strides_raw();
967976

968-
// destination must be ample enough to accommodate all possible elements
969-
{
970-
size_t range =
971-
static_cast<size_t>(dst_offsets.second - dst_offsets.first);
972-
if ((range + 1) < dst_nelems) {
973-
throw py::value_error(
974-
"Destination array can not accommodate all the "
975-
"elements of source array.");
976-
}
977-
}
978-
979977
// packed_shapes_strides = [dst_shape[:axis] + dst_shape[axis+k:],
980978
// dst_strides[:axis] + dst_strides[axis+k:],
981979
// val_strides[:axis] + val_strides[axis+k:]]

0 commit comments

Comments
 (0)