Skip to content

Commit 15da303

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Replace use of std::log2(size_t_value)
Replaced it with hand-written implementation of ceil_log2(n), such that n <= (dectype(n){1} << ceil_log2(n)) is true for all positive values of `n` in the range.
1 parent 2608aea commit 15da303

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,45 @@ class radix_sort_reorder_peer_kernel;
6363
template <std::uint32_t, bool, typename... TrailingNames>
6464
class radix_sort_reorder_kernel;
6565

66+
/*! @brief Computes smallest exponent such that `n <= (1 << exponent)` */
67+
template <typename SizeT,
68+
std::enable_if_t<std::is_unsigned_v<SizeT> &&
69+
sizeof(SizeT) == sizeof(std::uint64_t),
70+
int> = 0>
71+
std::uint32_t ceil_log2(SizeT n)
72+
{
73+
if (n <= 1)
74+
return std::uint32_t{1};
75+
76+
std::uint32_t exp{1};
77+
--n;
78+
if (n >= (SizeT{1} << 32)) {
79+
n >>= 32;
80+
exp += 32;
81+
}
82+
if (n >= (SizeT{1} << 16)) {
83+
n >>= 16;
84+
exp += 16;
85+
}
86+
if (n >= (SizeT{1} << 8)) {
87+
n >>= 8;
88+
exp += 8;
89+
}
90+
if (n >= (SizeT{1} << 4)) {
91+
n >>= 4;
92+
exp += 4;
93+
}
94+
if (n >= (SizeT{1} << 2)) {
95+
n >>= 2;
96+
exp += 2;
97+
}
98+
if (n >= (SizeT{1} << 1)) {
99+
n >>= 1;
100+
++exp;
101+
}
102+
return exp;
103+
}
104+
66105
//----------------------------------------------------------
67106
// bitwise order-preserving conversions to unsigned integers
68107
//----------------------------------------------------------
@@ -1145,7 +1184,7 @@ struct subgroup_radix_sort
11451184
const std::size_t max_slm_size =
11461185
dev.template get_info<sycl::info::device::local_mem_size>() / 2;
11471186

1148-
const auto n_uniform = 1 << (std::uint32_t(std::log2(n - 1)) + 1);
1187+
const auto n_uniform = 1 << ceil_log2(n);
11491188
const auto req_slm_size_val = sizeof(T) * n_uniform;
11501189

11511190
return ((req_slm_size_val + req_slm_size_counters) <= max_slm_size)

0 commit comments

Comments
 (0)