1
+ // Copyright (C) 2018-2020 - DevSH Graphics Programming Sp. z O.O.
2
+ // This file is part of the "Nabla Engine".
3
+ // For conditions of distribution and use, see copyright notice in nabla.h
4
+
5
+ #ifndef __NBL_CORE_RADIX_SORT_H_INCLUDED__
6
+ #define __NBL_CORE_RADIX_SORT_H_INCLUDED__
7
+
8
+ #include < algorithm>
9
+ #include < bitset>
10
+ #include < cstdint>
11
+ #include < numeric>
12
+
13
+ #include " nbl/macros.h"
14
+
15
+ namespace nbl
16
+ {
17
+ namespace core
18
+ {
19
+
20
+ namespace impl
21
+ {
22
+
23
+ template <typename T>
24
+ struct KeyAdaptor
25
+ {
26
+ static_assert (std::is_integral_v<T>&&std::is_unsigned_v<T>," Need to use your own key value accessor." );
27
+ _NBL_STATIC_INLINE_CONSTEXPR size_t key_bit_count = sizeof (T)*8u ;
28
+
29
+ template <auto bit_offset, auto radix_mask>
30
+ inline decltype (radix_mask) operator()(const T& item) const
31
+ {
32
+ return static_cast <decltype (radix_mask)>(item>>static_cast <T>(bit_offset))&radix_mask;
33
+ }
34
+ };
35
+
36
+ template <typename T>
37
+ constexpr uint8_t find_msb (const T& a_variable)
38
+ {
39
+ static_assert (std::is_unsigned<T>::value, " Variable must be unsigned" );
40
+
41
+ constexpr uint8_t number_of_bits = std::numeric_limits<T>::digits;
42
+ const std::bitset<number_of_bits> variable_bitset{a_variable};
43
+
44
+ for (uint8_t msb = number_of_bits - 1 ; msb >= 0 ; msb--)
45
+ {
46
+ if (variable_bitset[msb] == 1 )
47
+ return msb + 1 ;
48
+ }
49
+ return 0 ;
50
+ }
51
+
52
+ template <size_t key_bit_count, typename histogram_t >
53
+ struct RadixSorter
54
+ {
55
+ _NBL_STATIC_INLINE_CONSTEXPR uint16_t histogram_bytesize = 8192u ;
56
+ _NBL_STATIC_INLINE_CONSTEXPR size_t histogram_size = size_t (histogram_bytesize)/sizeof (histogram_t );
57
+ _NBL_STATIC_INLINE_CONSTEXPR uint8_t radix_bits = find_msb(histogram_size);
58
+ _NBL_STATIC_INLINE_CONSTEXPR size_t last_pass = (key_bit_count-1ull )/size_t (radix_bits);
59
+ _NBL_STATIC_INLINE_CONSTEXPR uint16_t radix_mask = (1u <<radix_bits)-1u ;
60
+
61
+ template <class RandomIt , class KeyAccessor >
62
+ inline RandomIt operator ()(RandomIt input, RandomIt output, const histogram_t rangeSize, const KeyAccessor& comp)
63
+ {
64
+ return pass<RandomIt,KeyAccessor,0ull >(input,output,rangeSize,comp);
65
+ }
66
+ private:
67
+ template <class RandomIt , class KeyAccessor , size_t pass_ix>
68
+ inline RandomIt pass (RandomIt input, RandomIt output, const histogram_t rangeSize, const KeyAccessor& comp)
69
+ {
70
+ // clear
71
+ std::fill_n (histogram,histogram_size,static_cast <histogram_t >(0u ));
72
+ // count
73
+ constexpr histogram_t shift = static_cast <histogram_t >(radix_bits*pass_ix);
74
+ for (histogram_t i=0u ; i<rangeSize; i++)
75
+ ++histogram[comp.operator ()<shift,radix_mask>(input[i])];
76
+ // prefix sum
77
+ std::inclusive_scan (histogram,histogram+histogram_size,histogram);
78
+ // scatter
79
+ for (histogram_t i=0u ; i<rangeSize; i++)
80
+ output[--histogram[comp.operator ()<shift,radix_mask>(input[i])]] = input[i];
81
+
82
+ if constexpr (pass_ix != last_pass)
83
+ return pass<RandomIt,KeyAccessor,pass_ix+1ull >(output,input,rangeSize,comp);
84
+ else
85
+ return output;
86
+ }
87
+
88
+ alignas (sizeof (histogram_t )) histogram_t histogram[histogram_size];
89
+ };
90
+
91
+ }
92
+
93
+ template <class RandomIt , class KeyAccessor >
94
+ inline RandomIt radix_sort (RandomIt input, RandomIt scratch, const size_t rangeSize, const KeyAccessor& comp)
95
+ {
96
+ assert (std::abs (std::distance (input,scratch))>=rangeSize);
97
+
98
+ if (rangeSize<static_cast <decltype (rangeSize)>(0x1ull <<16ull ))
99
+ return impl::RadixSorter<KeyAccessor::key_bit_count,uint16_t >()(input,scratch,static_cast <uint16_t >(rangeSize),comp);
100
+ if (rangeSize<static_cast <decltype (rangeSize)>(0x1ull <<32ull ))
101
+ return impl::RadixSorter<KeyAccessor::key_bit_count,uint32_t >()(input,scratch,static_cast <uint16_t >(rangeSize),comp);
102
+ else
103
+ return impl::RadixSorter<KeyAccessor::key_bit_count,size_t >()(input,scratch,rangeSize,comp);
104
+ }
105
+
106
+ // ! Because Radix Sort needs O(2n) space and a number of passes dependant on the key length, the final sorted range can be either in `input` or `scratch`
107
+ template <class RandomIt >
108
+ inline RandomIt radix_sort (RandomIt input, RandomIt scratch, const size_t rangeSize)
109
+ {
110
+ return radix_sort<RandomIt>(input,scratch,rangeSize,impl::KeyAdaptor<decltype (*input)>());
111
+ }
112
+
113
+ }
114
+ }
115
+
116
+ #endif
0 commit comments