Skip to content

Commit 6dc298c

Browse files
Merge pull request #1895 from IntelPython/contribution-to-1894
Contribution to 1894
2 parents 9e40df8 + 0493fdd commit 6dc298c

File tree

2 files changed

+41
-44
lines changed

2 files changed

+41
-44
lines changed

dpctl/tensor/libtensor/include/kernels/integer_advanced_indexing.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class TakeFunctor
9393
ssize_t src_offset = orthog_offsets.get_first_offset();
9494
ssize_t dst_offset = orthog_offsets.get_second_offset();
9595

96-
const ProjectorT proj{};
96+
constexpr ProjectorT proj{};
9797
for (int axis_idx = 0; axis_idx < k_; ++axis_idx) {
9898
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);
9999

@@ -239,7 +239,7 @@ class PutFunctor
239239
ssize_t dst_offset = orthog_offsets.get_first_offset();
240240
ssize_t val_offset = orthog_offsets.get_second_offset();
241241

242-
const ProjectorT proj{};
242+
constexpr ProjectorT proj{};
243243
for (int axis_idx = 0; axis_idx < k_; ++axis_idx) {
244244
indT *ind_data = reinterpret_cast<indT *>(ind_[axis_idx]);
245245

dpctl/tensor/libtensor/include/utils/indexing_utils.hpp

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -49,41 +49,40 @@ template <typename IndT> struct WrapIndex
4949
ssize_t operator()(ssize_t max_item, IndT ind) const
5050
{
5151
ssize_t projected;
52-
max_item = sycl::max<ssize_t>(max_item, 1);
52+
constexpr ssize_t unit(1);
53+
max_item = sycl::max(max_item, unit);
54+
55+
constexpr std::uintmax_t ind_max = std::numeric_limits<IndT>::max();
56+
constexpr std::uintmax_t ssize_max =
57+
std::numeric_limits<ssize_t>::max();
5358

5459
if constexpr (std::is_signed_v<IndT>) {
55-
static constexpr std::uintmax_t ind_max =
56-
std::numeric_limits<IndT>::max();
57-
static constexpr std::uintmax_t ssize_max =
58-
std::numeric_limits<ssize_t>::max();
59-
static constexpr std::intmax_t ind_min =
60-
std::numeric_limits<IndT>::min();
61-
static constexpr std::intmax_t ssize_min =
60+
constexpr std::intmax_t ind_min = std::numeric_limits<IndT>::min();
61+
constexpr std::intmax_t ssize_min =
6262
std::numeric_limits<ssize_t>::min();
6363

6464
if constexpr (ind_max <= ssize_max && ind_min >= ssize_min) {
65-
projected = sycl::clamp<ssize_t>(static_cast<ssize_t>(ind),
66-
-max_item, max_item - 1);
65+
const ssize_t ind_ = static_cast<ssize_t>(ind);
66+
const ssize_t lb = -max_item;
67+
const ssize_t ub = max_item - 1;
68+
projected = sycl::clamp(ind_, lb, ub);
6769
}
6870
else {
69-
projected = sycl::clamp<IndT>(ind, static_cast<IndT>(-max_item),
70-
static_cast<IndT>(max_item - 1));
71+
const IndT lb = static_cast<IndT>(-max_item);
72+
const IndT ub = static_cast<IndT>(max_item - 1);
73+
projected = static_cast<ssize_t>(sycl::clamp(ind, lb, ub));
7174
}
7275
return (projected < 0) ? projected + max_item : projected;
7376
}
7477
else {
75-
static constexpr std::uintmax_t ind_max =
76-
std::numeric_limits<IndT>::max();
77-
static constexpr std::uintmax_t ssize_max =
78-
std::numeric_limits<ssize_t>::max();
79-
8078
if constexpr (ind_max <= ssize_max) {
81-
projected =
82-
sycl::min<ssize_t>(static_cast<ssize_t>(ind), max_item - 1);
79+
const ssize_t ind_ = static_cast<ssize_t>(ind);
80+
const ssize_t ub = max_item - 1;
81+
projected = sycl::min(ind_, ub);
8382
}
8483
else {
85-
projected =
86-
sycl::min<IndT>(ind, static_cast<IndT>(max_item - 1));
84+
const IndT ub = static_cast<IndT>(max_item - 1);
85+
projected = static_cast<ssize_t>(sycl::min(ind, ub));
8786
}
8887
return projected;
8988
}
@@ -95,40 +94,38 @@ template <typename IndT> struct ClipIndex
9594
ssize_t operator()(ssize_t max_item, IndT ind) const
9695
{
9796
ssize_t projected;
98-
max_item = sycl::max<ssize_t>(max_item, 1);
97+
constexpr ssize_t unit(1);
98+
max_item = sycl::max<ssize_t>(max_item, unit);
9999

100+
constexpr std::uintmax_t ind_max = std::numeric_limits<IndT>::max();
101+
constexpr std::uintmax_t ssize_max =
102+
std::numeric_limits<ssize_t>::max();
100103
if constexpr (std::is_signed_v<IndT>) {
101-
static constexpr std::uintmax_t ind_max =
102-
std::numeric_limits<IndT>::max();
103-
static constexpr std::uintmax_t ssize_max =
104-
std::numeric_limits<ssize_t>::max();
105-
static constexpr std::intmax_t ind_min =
106-
std::numeric_limits<IndT>::min();
107-
static constexpr std::intmax_t ssize_min =
104+
constexpr std::intmax_t ind_min = std::numeric_limits<IndT>::min();
105+
constexpr std::intmax_t ssize_min =
108106
std::numeric_limits<ssize_t>::min();
109107

110108
if constexpr (ind_max <= ssize_max && ind_min >= ssize_min) {
111-
projected = sycl::clamp<ssize_t>(static_cast<ssize_t>(ind),
112-
ssize_t(0), max_item - 1);
109+
const ssize_t ind_ = static_cast<ssize_t>(ind);
110+
constexpr ssize_t lb(0);
111+
const ssize_t ub = max_item - 1;
112+
projected = sycl::clamp(ind_, lb, ub);
113113
}
114114
else {
115-
projected = sycl::clamp<IndT>(ind, IndT(0),
116-
static_cast<IndT>(max_item - 1));
115+
constexpr IndT lb(0);
116+
const IndT ub = static_cast<IndT>(max_item - 1);
117+
projected = static_cast<size_t>(sycl::clamp(ind, lb, ub));
117118
}
118119
}
119120
else {
120-
static constexpr std::uintmax_t ind_max =
121-
std::numeric_limits<IndT>::max();
122-
static constexpr std::uintmax_t ssize_max =
123-
std::numeric_limits<ssize_t>::max();
124-
125121
if constexpr (ind_max <= ssize_max) {
126-
projected =
127-
sycl::min<ssize_t>(static_cast<ssize_t>(ind), max_item - 1);
122+
const ssize_t ind_ = static_cast<ssize_t>(ind);
123+
const ssize_t ub = max_item - 1;
124+
projected = sycl::min(ind_, ub);
128125
}
129126
else {
130-
projected =
131-
sycl::min<IndT>(ind, static_cast<IndT>(max_item - 1));
127+
const IndT ub = static_cast<IndT>(max_item - 1);
128+
projected = static_cast<ssize_t>(sycl::min(ind, ub));
132129
}
133130
}
134131
return projected;

0 commit comments

Comments
 (0)