Skip to content

Commit 2ad5382

Browse files
Fix build break
Moved common constexpr variables out of branches. Replaced `static constexpr` with `constexpr`. Since these are defined in procedure scope, `static` is not required. Introduced typed temporary variables, so that type deduction for `sycl::min`, `sycl::max`, `sycl::clamp` can work and removed explicit use of their template parameter. Added explicit static_cast on value of `projected` variable computed as IndT type.
1 parent 9e40df8 commit 2ad5382

File tree

1 file changed

+39
-42
lines changed

1 file changed

+39
-42
lines changed

dpctl/tensor/libtensor/include/utils/indexing_utils.hpp

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -49,41 +49,40 @@ template <typename IndT> struct WrapIndex
4949
ssize_t operator()(ssize_t max_item, IndT ind) const
5050
{
5151
ssize_t projected;
52-
max_item = sycl::max<ssize_t>(max_item, 1);
52+
constexpr ssize_t unit(1);
53+
max_item = sycl::max(max_item, unit);
54+
55+
constexpr std::uintmax_t ind_max = std::numeric_limits<IndT>::max();
56+
constexpr std::uintmax_t ssize_max =
57+
std::numeric_limits<ssize_t>::max();
5358

5459
if constexpr (std::is_signed_v<IndT>) {
55-
static constexpr std::uintmax_t ind_max =
56-
std::numeric_limits<IndT>::max();
57-
static constexpr std::uintmax_t ssize_max =
58-
std::numeric_limits<ssize_t>::max();
59-
static constexpr std::intmax_t ind_min =
60-
std::numeric_limits<IndT>::min();
61-
static constexpr std::intmax_t ssize_min =
60+
constexpr std::intmax_t ind_min = std::numeric_limits<IndT>::min();
61+
constexpr std::intmax_t ssize_min =
6262
std::numeric_limits<ssize_t>::min();
6363

6464
if constexpr (ind_max <= ssize_max && ind_min >= ssize_min) {
65-
projected = sycl::clamp<ssize_t>(static_cast<ssize_t>(ind),
66-
-max_item, max_item - 1);
65+
const ssize_t ind_ = static_cast<ssize_t>(ind);
66+
const ssize_t lb = -max_item;
67+
const ssize_t ub = max_item - 1;
68+
projected = sycl::clamp(ind_, lb, ub);
6769
}
6870
else {
69-
projected = sycl::clamp<IndT>(ind, static_cast<IndT>(-max_item),
70-
static_cast<IndT>(max_item - 1));
71+
const IndT lb = static_cast<IndT>(-max_item);
72+
const IndT ub = static_cast<IndT>(max_item - 1);
73+
projected = static_cast<ssize_t>(sycl::clamp(ind, lb, ub));
7174
}
7275
return (projected < 0) ? projected + max_item : projected;
7376
}
7477
else {
75-
static constexpr std::uintmax_t ind_max =
76-
std::numeric_limits<IndT>::max();
77-
static constexpr std::uintmax_t ssize_max =
78-
std::numeric_limits<ssize_t>::max();
79-
8078
if constexpr (ind_max <= ssize_max) {
81-
projected =
82-
sycl::min<ssize_t>(static_cast<ssize_t>(ind), max_item - 1);
79+
const ssize_t ind_ = static_cast<ssize_t>(ind);
80+
const ssize_t ub = max_item - 1;
81+
projected = sycl::min(ind_, ub);
8382
}
8483
else {
85-
projected =
86-
sycl::min<IndT>(ind, static_cast<IndT>(max_item - 1));
84+
const IndT ub = static_cast<IndT>(max_item - 1);
85+
projected = static_cast<ssize_t>(sycl::min(ind, ub));
8786
}
8887
return projected;
8988
}
@@ -95,40 +94,38 @@ template <typename IndT> struct ClipIndex
9594
ssize_t operator()(ssize_t max_item, IndT ind) const
9695
{
9796
ssize_t projected;
98-
max_item = sycl::max<ssize_t>(max_item, 1);
97+
constexpr ssize_t unit(1);
98+
max_item = sycl::max<ssize_t>(max_item, unit);
9999

100+
constexpr std::uintmax_t ind_max = std::numeric_limits<IndT>::max();
101+
constexpr std::uintmax_t ssize_max =
102+
std::numeric_limits<ssize_t>::max();
100103
if constexpr (std::is_signed_v<IndT>) {
101-
static constexpr std::uintmax_t ind_max =
102-
std::numeric_limits<IndT>::max();
103-
static constexpr std::uintmax_t ssize_max =
104-
std::numeric_limits<ssize_t>::max();
105-
static constexpr std::intmax_t ind_min =
106-
std::numeric_limits<IndT>::min();
107-
static constexpr std::intmax_t ssize_min =
104+
constexpr std::intmax_t ind_min = std::numeric_limits<IndT>::min();
105+
constexpr std::intmax_t ssize_min =
108106
std::numeric_limits<ssize_t>::min();
109107

110108
if constexpr (ind_max <= ssize_max && ind_min >= ssize_min) {
111-
projected = sycl::clamp<ssize_t>(static_cast<ssize_t>(ind),
112-
ssize_t(0), max_item - 1);
109+
const ssize_t ind_ = static_cast<ssize_t>(ind);
110+
constexpr ssize_t lb(0);
111+
const ssize_t ub = max_item - 1;
112+
projected = sycl::clamp(ind_, lb, ub);
113113
}
114114
else {
115-
projected = sycl::clamp<IndT>(ind, IndT(0),
116-
static_cast<IndT>(max_item - 1));
115+
constexpr IndT lb(0);
116+
const IndT ub = static_cast<IndT>(max_item - 1);
117+
projected = static_cast<size_t>(sycl::clamp(ind, lb, ub));
117118
}
118119
}
119120
else {
120-
static constexpr std::uintmax_t ind_max =
121-
std::numeric_limits<IndT>::max();
122-
static constexpr std::uintmax_t ssize_max =
123-
std::numeric_limits<ssize_t>::max();
124-
125121
if constexpr (ind_max <= ssize_max) {
126-
projected =
127-
sycl::min<ssize_t>(static_cast<ssize_t>(ind), max_item - 1);
122+
const ssize_t ind_ = static_cast<ssize_t>(ind);
123+
const ssize_t ub = max_item - 1;
124+
projected = sycl::min(ind_, ub);
128125
}
129126
else {
130-
projected =
131-
sycl::min<IndT>(ind, static_cast<IndT>(max_item - 1));
127+
const IndT ub = static_cast<IndT>(max_item - 1);
128+
projected = static_cast<ssize_t>(sycl::min(ind, ub));
132129
}
133130
}
134131
return projected;

0 commit comments

Comments
 (0)