Skip to content

Commit 6239eb7

Browse files
committed
Changes to advanced indexing
- Clipping now clips indices to -n <= i < n for n = axis size - Fixed a segfault caused by a typo when copying strides
1 parent 832a981 commit 6239eb7

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

dpctl/tensor/libtensor/include/kernels/advanced_indexing.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ template <typename indT> class ClipIndex
5454
{
5555
max_item = (max_item > 0) ? max_item : 1;
5656
py::ssize_t clip_ind = static_cast<py::ssize_t>(ind);
57-
ind = (ind < 0) ? 0 : (clip_ind >= max_item) ? (max_item - 1) : ind;
57+
ind = (ind < 0) ? (clip_ind <= -max_item) ? (0) : (clip_ind + max_item)
58+
: (clip_ind >= max_item) ? (max_item - 1)
59+
: ind;
5860
return;
5961
}
6062
};

dpctl/tensor/libtensor/source/advanced_indexing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
197197
std::copy(arr_strides, arr_strides + axis_start,
198198
packed_host_shapes_strides_shp->begin() +
199199
2 * orthog_sh_elems);
200-
std::copy(arr_strides + axis_start + ind_nd, arr_strides + inp_nd,
200+
std::copy(arr_strides + axis_start + ind_nd, arr_strides + arr_nd,
201201
packed_host_shapes_strides_shp->begin() +
202202
2 * orthog_sh_elems + axis_start);
203203
std::copy(arr_strides + axis_start,

0 commit comments

Comments
 (0)