Skip to content

Commit 639ad23

Browse files
committed
Move radix_sort utilities to a common header file and add set_bucket_id function
1 parent 5e35e0e commit 639ad23

File tree

2 files changed

+399
-194
lines changed

2 files changed

+399
-194
lines changed
Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
//
2+
// Data Parallel Control (dpctl)
3+
//
4+
// Copyright 2020-2024 Intel Corporation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
//
18+
//===--------------------------------------------------------------------===//
19+
///
20+
/// \file
21+
/// This file defines utility functions common to radix-based algorithms
22+
/// such as radix sorting and radix selection.
23+
//===--------------------------------------------------------------------===//
24+
25+
#pragma once
26+
27+
#include <cstddef>
28+
#include <cstdint>
29+
#include <limits>
30+
#include <type_traits>
31+
32+
#include <sycl/sycl.hpp>
33+
34+
namespace dpctl
35+
{
36+
namespace tensor
37+
{
38+
namespace kernels
39+
{
40+
41+
namespace radix_common
42+
{
43+
44+
//----------------------------------------------------------
45+
// bitwise order-preserving conversions to unsigned integers
46+
//----------------------------------------------------------
47+
48+
template <bool is_ascending, typename T, typename Enable = void>
49+
struct RadixTypeConfig
50+
{
51+
};
52+
53+
template <bool is_ascending> struct RadixTypeConfig<is_ascending, bool>
54+
{
55+
typedef bool RadixType;
56+
57+
static inline RadixType encode(bool val)
58+
{
59+
if constexpr (is_ascending)
60+
return val;
61+
else
62+
return !val;
63+
}
64+
65+
static inline bool decode(RadixType val)
66+
{
67+
if constexpr (is_ascending)
68+
return val;
69+
else
70+
return !val;
71+
}
72+
};
73+
74+
template <bool is_ascending, typename UIntT>
75+
struct RadixTypeConfig<is_ascending,
76+
UIntT,
77+
std::enable_if_t<std::is_unsigned_v<UIntT>>>
78+
{
79+
typedef UIntT RadixType;
80+
81+
static inline RadixType encode(UIntT val)
82+
{
83+
if constexpr (is_ascending) {
84+
return val;
85+
}
86+
else {
87+
// bitwise invert
88+
return (~val);
89+
}
90+
}
91+
92+
static inline UIntT decode(RadixType val)
93+
{
94+
if constexpr (is_ascending) {
95+
return val;
96+
}
97+
else {
98+
// bitwise invert
99+
return (~val);
100+
}
101+
}
102+
};
103+
104+
template <bool is_ascending, typename IntT>
105+
struct RadixTypeConfig<
106+
is_ascending,
107+
IntT,
108+
std::enable_if_t<std::is_integral_v<IntT> && std::is_signed_v<IntT>>>
109+
{
110+
typedef std::make_unsigned_t<IntT> RadixType;
111+
112+
static inline RadixType encode(IntT val)
113+
{
114+
// ascending_mask: 100..0
115+
constexpr RadixType ascending_mask =
116+
(RadixType(1) << std::numeric_limits<IntT>::digits);
117+
// descending_mask: 011..1
118+
constexpr RadixType descending_mask =
119+
(std::numeric_limits<RadixType>::max() >> 1);
120+
121+
constexpr RadixType mask =
122+
(is_ascending) ? ascending_mask : descending_mask;
123+
const RadixType uint_val = sycl::bit_cast<RadixType>(val);
124+
125+
return (uint_val ^ mask);
126+
}
127+
128+
static inline IntT decode(RadixType val)
129+
{
130+
// ascending_mask: 100..0
131+
constexpr RadixType ascending_mask =
132+
(RadixType(1) << std::numeric_limits<IntT>::digits);
133+
// descending_mask: 011..1
134+
constexpr RadixType descending_mask =
135+
(std::numeric_limits<RadixType>::max() >> 1);
136+
137+
constexpr RadixType mask =
138+
(is_ascending) ? ascending_mask : descending_mask;
139+
const IntT int_val = sycl::bit_cast<IntT>(val);
140+
141+
return (int_val ^ mask);
142+
}
143+
};
144+
145+
template <bool is_ascending> struct RadixTypeConfig<is_ascending, sycl::half>
146+
{
147+
typedef std::uint16_t RadixType;
148+
149+
static inline RadixType encode(sycl::half val)
150+
{
151+
const RadixType uint_val = sycl::bit_cast<RadixType>(
152+
(sycl::isnan(val)) ? std::numeric_limits<sycl::half>::quiet_NaN()
153+
: val);
154+
RadixType mask;
155+
156+
// test the sign bit of the original value
157+
const bool zero_fp_sign_bit = (RadixType(0) == (uint_val >> 15));
158+
159+
constexpr RadixType zero_mask = RadixType(0x8000u);
160+
constexpr RadixType nonzero_mask = RadixType(0xFFFFu);
161+
162+
constexpr RadixType inv_zero_mask = static_cast<RadixType>(~zero_mask);
163+
constexpr RadixType inv_nonzero_mask =
164+
static_cast<RadixType>(~nonzero_mask);
165+
166+
if constexpr (is_ascending) {
167+
mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask;
168+
}
169+
else {
170+
mask = (zero_fp_sign_bit) ? (inv_zero_mask) : (inv_nonzero_mask);
171+
}
172+
173+
return (uint_val ^ mask);
174+
}
175+
176+
static inline sycl::half decode(RadixType uint_val)
177+
{
178+
RadixType mask;
179+
180+
// test the sign bit of the original value
181+
const bool zero_fp_sign_bit = (RadixType(0) == (uint_val >> 15));
182+
183+
constexpr RadixType nonzero_mask = RadixType(0x8000u);
184+
constexpr RadixType zero_mask = RadixType(0xFFFFu);
185+
186+
constexpr RadixType inv_nonzero_mask =
187+
static_cast<RadixType>(~nonzero_mask);
188+
constexpr RadixType inv_zero_mask = static_cast<RadixType>(~zero_mask);
189+
190+
if constexpr (is_ascending) {
191+
mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask;
192+
}
193+
else {
194+
mask = (zero_fp_sign_bit) ? (inv_zero_mask) : (inv_nonzero_mask);
195+
}
196+
197+
const RadixType masked = uint_val ^ mask;
198+
return sycl::bit_cast<sycl::half>(masked);
199+
}
200+
};
201+
202+
template <bool is_ascending, typename FloatT>
203+
struct RadixTypeConfig<
204+
is_ascending,
205+
FloatT,
206+
std::enable_if_t<std::is_floating_point_v<FloatT> &&
207+
sizeof(FloatT) == sizeof(std::uint32_t)>>
208+
{
209+
typedef std::uint32_t RadixType;
210+
211+
static inline RadixType encode(FloatT val)
212+
{
213+
RadixType uint_val = sycl::bit_cast<RadixType>(
214+
(sycl::isnan(val)) ? std::numeric_limits<FloatT>::quiet_NaN()
215+
: val);
216+
217+
RadixType mask;
218+
219+
// test the sign bit of the original value
220+
const bool zero_fp_sign_bit = (RadixType(0) == (uint_val >> 31));
221+
222+
constexpr RadixType zero_mask = RadixType(0x80000000u);
223+
constexpr RadixType nonzero_mask = RadixType(0xFFFFFFFFu);
224+
225+
if constexpr (is_ascending)
226+
mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask;
227+
else
228+
mask = (zero_fp_sign_bit) ? (~zero_mask) : (~nonzero_mask);
229+
230+
return (uint_val ^ mask);
231+
}
232+
233+
static inline FloatT decode(RadixType uint_val)
234+
{
235+
RadixType mask;
236+
237+
// test the sign bit of the original value
238+
const bool zero_fp_sign_bit = (RadixType(0) == (uint_val >> 31));
239+
240+
constexpr RadixType zero_mask = RadixType(0xFFFFFFFFu);
241+
constexpr RadixType nonzero_mask = RadixType(0x80000000u);
242+
243+
if constexpr (is_ascending)
244+
mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask;
245+
else
246+
mask = (zero_fp_sign_bit) ? (~zero_mask) : (~nonzero_mask);
247+
248+
const RadixType masked = uint_val ^ mask;
249+
return sycl::bit_cast<FloatT>(masked);
250+
}
251+
};
252+
253+
template <bool is_ascending, typename FloatT>
254+
struct RadixTypeConfig<
255+
is_ascending,
256+
FloatT,
257+
std::enable_if_t<std::is_floating_point_v<FloatT> &&
258+
sizeof(FloatT) == sizeof(std::uint64_t)>>
259+
{
260+
typedef std::uint64_t RadixType;
261+
262+
static inline RadixType encode(FloatT val)
263+
{
264+
265+
RadixType uint_val = sycl::bit_cast<RadixType>(
266+
(sycl::isnan(val)) ? std::numeric_limits<FloatT>::quiet_NaN()
267+
: val);
268+
RadixType mask;
269+
270+
// test the sign bit of the original value
271+
const bool zero_fp_sign_bit = (RadixType(0) == (uint_val >> 63));
272+
273+
constexpr RadixType zero_mask = RadixType(0x8000000000000000u);
274+
constexpr RadixType nonzero_mask = RadixType(0xFFFFFFFFFFFFFFFFu);
275+
276+
if constexpr (is_ascending)
277+
mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask;
278+
else
279+
mask = (zero_fp_sign_bit) ? (~zero_mask) : (~nonzero_mask);
280+
281+
return (uint_val ^ mask);
282+
}
283+
284+
static inline FloatT decode(RadixType uint_val)
285+
{
286+
RadixType mask;
287+
288+
// test the sign bit of the original value
289+
const bool zero_fp_sign_bit = (RadixType(0) == (uint_val >> 63));
290+
291+
constexpr RadixType zero_mask = RadixType(0xFFFFFFFFFFFFFFFFu);
292+
constexpr RadixType nonzero_mask = RadixType(0x8000000000000000u);
293+
294+
if constexpr (is_ascending)
295+
mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask;
296+
else
297+
mask = (zero_fp_sign_bit) ? (~zero_mask) : (~nonzero_mask);
298+
299+
const RadixType masked = uint_val ^ mask;
300+
return sycl::bit_cast<FloatT>(masked);
301+
}
302+
};
303+
304+
template <bool is_ascending, typename T>
305+
typename RadixTypeConfig<is_ascending, T>::RadixType
306+
order_preserving_cast(T val)
307+
{
308+
return RadixTypeConfig<is_ascending, T>::encode(val);
309+
}
310+
311+
//-----------------
312+
// bucket functions
313+
//-----------------
314+
315+
template <typename T> constexpr std::size_t number_of_bits_in_type()
316+
{
317+
constexpr std::size_t type_bits =
318+
(sizeof(T) * std::numeric_limits<unsigned char>::digits);
319+
return type_bits;
320+
}
321+
322+
// the number of buckets (size of radix bits) in T
323+
template <typename T>
324+
constexpr std::uint32_t number_of_buckets_in_type(std::uint32_t radix_bits)
325+
{
326+
constexpr std::size_t type_bits = number_of_bits_in_type<T>();
327+
return (type_bits + radix_bits - 1) / radix_bits;
328+
}
329+
330+
// get bits value (bucket) in a certain radix position
331+
template <std::uint32_t radix_mask, typename T>
332+
std::uint32_t get_bucket_id(T val, std::uint32_t radix_offset)
333+
{
334+
static_assert(std::is_unsigned_v<T>);
335+
336+
return (val >> radix_offset) & T(radix_mask);
337+
}
338+
339+
template <std::uint32_t radix_mask, typename T>
340+
T set_bucket_id(T val, T insert, std::uint32_t radix_offset)
341+
{
342+
static_assert(std::is_unsigned_v<T>);
343+
344+
T m = radix_mask;
345+
insert &= m;
346+
insert <<= radix_offset;
347+
m <<= radix_offset;
348+
return (val & ~m) | insert;
349+
}
350+
351+
} // namespace radix_common
352+
} // namespace kernels
353+
} // namespace tensor
354+
} // namespace dpctl

0 commit comments

Comments
 (0)