Skip to content

Commit 728b8e6

Browse files
committed
Fixed missing cast for indices clip/wrap
1 parent a0895be commit 728b8e6

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

dpctl/tensor/libtensor/include/kernels/advanced_indexing.hpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,34 +45,32 @@ namespace py = pybind11;
4545
template <typename ProjectorT, typename Ty, typename indT> class take_kernel;
4646
template <typename ProjectorT, typename Ty, typename indT> class put_kernel;
4747

48-
template <typename indT> class ClipIndex
48+
class ClipIndex
4949
{
5050
public:
5151
ClipIndex() = default;
5252

53-
void operator()(py::ssize_t max_item, indT &ind) const
53+
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
5454
{
5555
max_item = (max_item > 0) ? max_item : 1;
56-
py::ssize_t clip_ind = static_cast<py::ssize_t>(ind);
57-
ind = (ind < 0) ? (clip_ind <= -max_item) ? (0) : (clip_ind + max_item)
58-
: (clip_ind >= max_item) ? (max_item - 1)
59-
: ind;
56+
ind = (ind < 0) ? (ind <= -max_item) ? (0) : (ind + max_item)
57+
: (ind >= max_item) ? (max_item - 1)
58+
: ind;
6059
return;
6160
}
6261
};
6362

64-
template <typename indT> class WrapIndex
63+
class WrapIndex
6564
{
6665
public:
6766
WrapIndex() = default;
6867

69-
void operator()(py::ssize_t max_item, indT &ind) const
68+
void operator()(py::ssize_t max_item, py::ssize_t &ind) const
7069
{
7170
max_item = (max_item > 0) ? max_item : 1;
72-
py::ssize_t wrap_ind = static_cast<py::ssize_t>(ind);
73-
ind = (ind < 0) ? max_item - (-wrap_ind % max_item)
74-
: (wrap_ind >= max_item) ? wrap_ind % max_item
75-
: ind;
71+
ind = (ind < 0) ? max_item - (-ind % max_item)
72+
: (ind >= max_item) ? ind % max_item
73+
: ind;
7674
return;
7775
}
7876
};
@@ -146,7 +144,8 @@ template <typename ProjectorT, typename T, typename indT> class TakeFunctor
146144
ind_shape_and_strides_ + ((axis_idx + 1) * ind_nd_),
147145
ind_arr_idx);
148146
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);
149-
indT i = ind_data[ind_arr_idx + ind_offsets_[axis_idx]];
147+
py::ssize_t i = static_cast<py::ssize_t>(
148+
ind_data[ind_arr_idx + ind_offsets_[axis_idx]]);
150149
proj(axes_shape_and_strides_[axis_idx], i);
151150
src_orthog_idx += i * axes_shape_and_strides_[k_ + axis_idx];
152151
}
@@ -282,7 +281,8 @@ template <typename ProjectorT, typename T, typename indT> class PutFunctor
282281
ind_shape_and_strides_ + ((axis_idx + 1) * ind_nd_),
283282
ind_arr_idx);
284283
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);
285-
indT i = ind_data[ind_arr_idx + ind_offsets_[axis_idx]];
284+
py::ssize_t i = static_cast<py::ssize_t>(
285+
ind_data[ind_arr_idx + ind_offsets_[axis_idx]]);
286286
proj(axes_shape_and_strides_[axis_idx], i);
287287
dst_orthog_idx += i * axes_shape_and_strides_[k_ + axis_idx];
288288
}
@@ -355,7 +355,7 @@ template <typename fnT, typename T, typename indT> struct TakeWrapFactory
355355
{
356356
if constexpr (std::is_integral<indT>::value &&
357357
!std::is_same<indT, bool>::value) {
358-
fnT fn = take_impl<WrapIndex<indT>, T, indT>;
358+
fnT fn = take_impl<WrapIndex, T, indT>;
359359
return fn;
360360
}
361361
else {
@@ -371,7 +371,7 @@ template <typename fnT, typename T, typename indT> struct TakeClipFactory
371371
{
372372
if constexpr (std::is_integral<indT>::value &&
373373
!std::is_same<indT, bool>::value) {
374-
fnT fn = take_impl<ClipIndex<indT>, T, indT>;
374+
fnT fn = take_impl<ClipIndex, T, indT>;
375375
return fn;
376376
}
377377
else {
@@ -387,7 +387,7 @@ template <typename fnT, typename T, typename indT> struct PutWrapFactory
387387
{
388388
if constexpr (std::is_integral<indT>::value &&
389389
!std::is_same<indT, bool>::value) {
390-
fnT fn = put_impl<WrapIndex<indT>, T, indT>;
390+
fnT fn = put_impl<WrapIndex, T, indT>;
391391
return fn;
392392
}
393393
else {
@@ -403,7 +403,7 @@ template <typename fnT, typename T, typename indT> struct PutClipFactory
403403
{
404404
if constexpr (std::is_integral<indT>::value &&
405405
!std::is_same<indT, bool>::value) {
406-
fnT fn = put_impl<ClipIndex<indT>, T, indT>;
406+
fnT fn = put_impl<ClipIndex, T, indT>;
407407
return fn;
408408
}
409409
else {

0 commit comments

Comments
 (0)