Skip to content

Commit 509b359

Browse files
author
kevyuu
committed
Make radix sort more efficient
1 parent e91aba1 commit 509b359

File tree

1 file changed

+18
-35
lines changed

1 file changed

+18
-35
lines changed

include/nbl/core/algorithm/radix_sort.h

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -67,60 +67,43 @@ struct RadixLsbSorter
6767
return pass<RandomIt,KeyAccessor,0ull>(input,output,rangeSize,comp);
6868
}
6969

70-
std::pair<histogram_t, histogram_t> getHashBound(size_t key) const
70+
std::pair<histogram_t, histogram_t> getMostSignificantRadixBound(size_t key) const
7171
{
7272
constexpr histogram_t shift = static_cast<histogram_t>(radix_bits * last_pass);
7373
const auto histogramIx = (key >> shift) & radix_mask;
74-
const auto boundBegin = histogramIx == 0 ? 0 : histogram[histogramIx - 1];
75-
return { boundBegin, histogram[histogramIx] };
74+
const auto boundBegin = histogramIx == 0 ? 0 : m_histogram[histogramIx - 1];
75+
return { boundBegin, m_histogram[histogramIx] };
7676
}
7777

7878
private:
7979
template<class RandomIt, class KeyAccessor, size_t pass_ix>
8080
inline RandomIt pass(RandomIt input, RandomIt output, const histogram_t rangeSize, const KeyAccessor& comp)
8181
{
8282
// clear
83-
std::fill_n(histogram,histogram_size,static_cast<histogram_t>(0u));
83+
std::fill_n(m_histogram,histogram_size,static_cast<histogram_t>(0u));
84+
8485
// count
8586
constexpr histogram_t shift = static_cast<histogram_t>(radix_bits*pass_ix);
8687
for (histogram_t i = 0u; i < rangeSize; i++)
87-
++histogram[comp.template operator()<shift,radix_mask>(input[i])];
88+
++m_histogram[comp.template operator()<shift,radix_mask>(input[i])];
89+
8890
// prefix sum
89-
std::inclusive_scan(histogram, histogram + histogram_size, histogram);
90-
// scatter
91+
std::exclusive_scan(m_histogram, m_histogram + histogram_size, m_histogram, 0);
9192

92-
if constexpr (pass_ix != last_pass)
93+
// scatter. After scatter m_histogram now become a skiplist
94+
for (histogram_t i = 0; i < rangeSize; i++)
9395
{
94-
95-
for (histogram_t i = rangeSize; i != 0u;)
96-
{
97-
i--;
98-
const auto& val = input[i];
99-
const auto& histogramIx = comp.template operator()<shift,radix_mask>(val);
100-
output[--histogram[histogramIx]] = val;
101-
}
102-
103-
return pass<RandomIt,KeyAccessor,pass_ix+1ull>(output,input,rangeSize,comp);
104-
}
105-
else
106-
{
107-
// need to preserve histogram value for the skip list, so we copy to temporary histogramArray and use that
108-
std::array<histogram_t, histogram_size> tmpHistogram;
109-
std::copy(histogram, histogram + histogram_size, tmpHistogram.data());
110-
111-
for (histogram_t i = rangeSize; i != 0u;)
112-
{
113-
i--;
114-
const auto& val = input[i];
115-
const auto& histogramIx = comp.template operator()<shift,radix_mask>(val);
116-
output[--tmpHistogram[histogramIx]] = val;
117-
}
118-
119-
return output;
96+
const auto& val = input[i];
97+
const auto& histogramIx = comp.template operator()<shift,radix_mask>(val);
98+
output[m_histogram[histogramIx]++] = val;
12099
}
100+
101+
if constexpr (pass_ix != last_pass)
102+
return pass<RandomIt,KeyAccessor,pass_ix+1ull>(output,input,rangeSize,comp);
103+
return output;
121104
}
122105

123-
alignas(sizeof(histogram_t)) histogram_t histogram[histogram_size];
106+
alignas(sizeof(histogram_t)) histogram_t m_histogram[histogram_size];
124107
};
125108

126109
template<class RandomIt, class KeyAccessor>

0 commit comments

Comments
 (0)