Skip to content

Commit 5913185

Browse files
<algorithm>: Optimize sample() and shuffle() with Lemire's algorithm (#5735)
1 parent fa166f2 commit 5913185

File tree

7 files changed

+259
-149
lines changed

7 files changed

+259
-149
lines changed

benchmarks/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,10 @@ add_benchmark(remove src/remove.cpp)
128128
add_benchmark(replace src/replace.cpp)
129129
add_benchmark(reverse src/reverse.cpp)
130130
add_benchmark(rotate src/rotate.cpp)
131+
add_benchmark(sample src/sample.cpp)
131132
add_benchmark(search src/search.cpp)
132133
add_benchmark(search_n src/search_n.cpp)
134+
add_benchmark(shuffle src/shuffle.cpp)
133135
add_benchmark(std_copy src/std_copy.cpp)
134136
add_benchmark(sv_equal src/sv_equal.cpp)
135137
add_benchmark(swap_ranges src/swap_ranges.cpp)

benchmarks/src/sample.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include <algorithm>
5+
#include <benchmark/benchmark.h>
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <numeric>
9+
#include <random>
10+
#include <type_traits>
11+
#include <vector>
12+
using namespace std;
13+
14+
enum class alg_type { std_fn, rng };
15+
16+
template <class T, alg_type Alg>
17+
void bm_sample(benchmark::State& state) {
18+
static_assert(is_unsigned_v<T>, "T must be unsigned so iota() doesn't have to worry about overflow.");
19+
20+
const auto population_size = static_cast<size_t>(state.range(0));
21+
const auto sampled_size = static_cast<size_t>(state.range(1));
22+
23+
vector<T> population(population_size);
24+
vector<T> sampled(sampled_size);
25+
iota(population.begin(), population.end(), T{0});
26+
mt19937_64 urbg;
27+
28+
for (auto _ : state) {
29+
benchmark::DoNotOptimize(population);
30+
if constexpr (Alg == alg_type::rng) {
31+
ranges::sample(population, sampled.begin(), sampled_size, urbg);
32+
} else {
33+
sample(population.begin(), population.end(), sampled.begin(), sampled_size, urbg);
34+
}
35+
benchmark::DoNotOptimize(sampled);
36+
}
37+
}
38+
39+
void common_args(auto bm) {
40+
bm->Args({1 << 20, 1 << 15});
41+
}
42+
43+
BENCHMARK(bm_sample<uint8_t, alg_type::std_fn>)->Apply(common_args);
44+
BENCHMARK(bm_sample<uint16_t, alg_type::std_fn>)->Apply(common_args);
45+
BENCHMARK(bm_sample<uint32_t, alg_type::std_fn>)->Apply(common_args);
46+
BENCHMARK(bm_sample<uint64_t, alg_type::std_fn>)->Apply(common_args);
47+
48+
BENCHMARK(bm_sample<uint8_t, alg_type::rng>)->Apply(common_args);
49+
BENCHMARK(bm_sample<uint16_t, alg_type::rng>)->Apply(common_args);
50+
BENCHMARK(bm_sample<uint32_t, alg_type::rng>)->Apply(common_args);
51+
BENCHMARK(bm_sample<uint64_t, alg_type::rng>)->Apply(common_args);
52+
53+
BENCHMARK_MAIN();

benchmarks/src/shuffle.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include <algorithm>
5+
#include <benchmark/benchmark.h>
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <numeric>
9+
#include <random>
10+
#include <type_traits>
11+
#include <vector>
12+
using namespace std;
13+
14+
enum class alg_type { std_fn, rng };
15+
16+
template <class T, alg_type Alg>
17+
void bm_shuffle(benchmark::State& state) {
18+
static_assert(is_unsigned_v<T>, "T must be unsigned so iota() doesn't have to worry about overflow.");
19+
20+
const auto n = static_cast<size_t>(state.range(0));
21+
vector<T> v(n);
22+
iota(v.begin(), v.end(), T{0});
23+
mt19937_64 urbg;
24+
25+
for (auto _ : state) {
26+
benchmark::DoNotOptimize(v);
27+
if constexpr (Alg == alg_type::rng) {
28+
ranges::shuffle(v, urbg);
29+
} else {
30+
shuffle(v.begin(), v.end(), urbg);
31+
}
32+
}
33+
}
34+
35+
void common_args(auto bm) {
36+
bm->Arg(1 << 20);
37+
}
38+
39+
BENCHMARK(bm_shuffle<uint8_t, alg_type::std_fn>)->Apply(common_args);
40+
BENCHMARK(bm_shuffle<uint16_t, alg_type::std_fn>)->Apply(common_args);
41+
BENCHMARK(bm_shuffle<uint32_t, alg_type::std_fn>)->Apply(common_args);
42+
BENCHMARK(bm_shuffle<uint64_t, alg_type::std_fn>)->Apply(common_args);
43+
44+
BENCHMARK(bm_shuffle<uint8_t, alg_type::rng>)->Apply(common_args);
45+
BENCHMARK(bm_shuffle<uint16_t, alg_type::rng>)->Apply(common_args);
46+
BENCHMARK(bm_shuffle<uint32_t, alg_type::rng>)->Apply(common_args);
47+
BENCHMARK(bm_shuffle<uint64_t, alg_type::rng>)->Apply(common_args);
48+
49+
BENCHMARK_MAIN();

stl/inc/algorithm

Lines changed: 144 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <yvals_core.h>
99
#if _STL_COMPILER_PREPROCESSOR
1010
#include <__msvc_heap_algorithms.hpp>
11+
#include <__msvc_int128.hpp>
1112
#include <__msvc_minmax.hpp>
1213
#include <xmemory>
1314

@@ -6019,6 +6020,141 @@ private:
60196020
_Udiff _Bmask; // 2^_Bits - 1
60206021
};
60216022

6023+
template <class _Diff, class _Urng>
6024+
class _Rng_from_urng_v2 { // wrap a URNG as an RNG
6025+
public:
6026+
using _Ty0 = make_unsigned_t<_Diff>;
6027+
using _Ty1 = _Invoke_result_t<_Urng&>;
6028+
6029+
using _Udiff = conditional_t<sizeof(_Ty1) < sizeof(_Ty0), _Ty0, _Ty1>;
6030+
static constexpr unsigned int _Udiff_bits = sizeof(_Udiff) * CHAR_BIT;
6031+
using _Uprod = conditional_t<_Udiff_bits <= 16, uint32_t, conditional_t<_Udiff_bits <= 32, uint64_t, _Unsigned128>>;
6032+
6033+
explicit _Rng_from_urng_v2(_Urng& _Func) noexcept : _Ref(_Func) {}
6034+
6035+
_Diff operator()(_Diff _Index) { // adapt _Urng closed range to [0, _Index)
6036+
// From Daniel Lemire, "Fast Random Integer Generation in an Interval",
6037+
// ACM Trans. Model. Comput. Simul. 29 (1), 2019.
6038+
//
6039+
// Algorithm 5 <-> This Code:
6040+
// m <-> _Product
6041+
// l <-> _Rem
6042+
// s <-> _Index
6043+
// t <-> _Threshold
6044+
// L <-> _Generated_bits
6045+
// 2^L - 1 <-> _Mask
6046+
6047+
_Udiff _Mask = _Bmask;
6048+
unsigned int _Niter = 1;
6049+
6050+
if constexpr (_Bits < _Udiff_bits) {
6051+
while (_Mask < static_cast<_Udiff>(_Index - 1)) {
6052+
_Mask <<= _Bits;
6053+
_Mask |= _Bmask;
6054+
++_Niter;
6055+
}
6056+
}
6057+
6058+
// x <- random integer in [0, 2^L)
6059+
// m <- x * s
6060+
auto _Product = _Get_random_product(_Index, _Niter);
6061+
// l <- m mod 2^L
6062+
auto _Rem = static_cast<_Udiff>(_Product) & _Mask;
6063+
6064+
if (_Rem < static_cast<_Udiff>(_Index)) {
6065+
// t <- (2^L - s) mod s
6066+
const auto _Threshold = (_Mask - _Index + 1) % _Index;
6067+
while (_Rem < _Threshold) {
6068+
_Product = _Get_random_product(_Index, _Niter);
6069+
_Rem = static_cast<_Udiff>(_Product) & _Mask;
6070+
}
6071+
}
6072+
6073+
unsigned int _Generated_bits;
6074+
if constexpr (_Bits < _Udiff_bits) {
6075+
_Generated_bits = static_cast<unsigned int>(_Popcount(_Mask));
6076+
} else {
6077+
_Generated_bits = _Udiff_bits;
6078+
}
6079+
6080+
// m / 2^L
6081+
return static_cast<_Diff>(_Product >> _Generated_bits);
6082+
}
6083+
6084+
_Udiff _Get_all_bits() {
6085+
_Udiff _Ret = _Get_bits();
6086+
6087+
if constexpr (_Bits < _Udiff_bits) {
6088+
for (unsigned int _Num = _Bits; _Num < _Udiff_bits; _Num += _Bits) { // don't mask away any bits
6089+
_Ret <<= _Bits;
6090+
_Ret |= _Get_bits();
6091+
}
6092+
}
6093+
6094+
return _Ret;
6095+
}
6096+
6097+
_Rng_from_urng_v2(const _Rng_from_urng_v2&) = delete;
6098+
_Rng_from_urng_v2& operator=(const _Rng_from_urng_v2&) = delete;
6099+
6100+
private:
6101+
_Udiff _Get_bits() { // return a random value within [0, _Bmask]
6102+
constexpr auto _Urng_min = (_Urng::min) ();
6103+
for (;;) { // repeat until random value is in range
6104+
const _Udiff _Val = static_cast<_Udiff>(_Ref() - _Urng_min);
6105+
6106+
if (_Val <= _Bmask) {
6107+
return _Val;
6108+
}
6109+
}
6110+
}
6111+
6112+
static constexpr size_t _Calc_bits() {
6113+
auto _Bits_local = _Udiff_bits;
6114+
auto _Bmask_local = static_cast<_Udiff>(-1);
6115+
for (; static_cast<_Udiff>((_Urng::max) () - (_Urng::min) ()) < _Bmask_local; _Bmask_local >>= 1) {
6116+
--_Bits_local;
6117+
}
6118+
6119+
return _Bits_local;
6120+
}
6121+
6122+
_Uprod _Get_random_product(const _Diff _Index, unsigned int _Niter) {
6123+
_Udiff _Ret = _Get_bits();
6124+
if constexpr (_Bits < _Udiff_bits) {
6125+
while (--_Niter > 0) {
6126+
_Ret <<= _Bits;
6127+
_Ret |= _Get_bits();
6128+
}
6129+
}
6130+
6131+
if constexpr (is_same_v<_Udiff, uint64_t>) {
6132+
uint64_t _High;
6133+
const auto _Low = _Base128::_UMul128(_Ret, static_cast<_Udiff>(_Index), _High);
6134+
return _Uprod{_Low, _High};
6135+
} else {
6136+
return _Uprod{_Ret} * static_cast<_Uprod>(_Index);
6137+
}
6138+
}
6139+
6140+
_Urng& _Ref; // reference to URNG
6141+
static constexpr size_t _Bits = _Calc_bits(); // number of random bits generated by _Get_bits()
6142+
static constexpr _Udiff _Bmask = static_cast<_Udiff>(-1) >> (_Udiff_bits - _Bits); // 2^_Bits - 1
6143+
};
6144+
6145+
template <class _Gen, class = void>
6146+
constexpr bool _Has_static_min_max = false;
6147+
6148+
// This checks a requirement of N4981 [rand.req.urng] `concept uniform_random_bit_generator` but doesn't attempt
6149+
// to implement the whole concept - we just need to distinguish Standard machinery from tr1 machinery.
6150+
template <class _Gen>
6151+
constexpr bool _Has_static_min_max<_Gen, void_t<decltype(bool_constant<(_Gen::min) () < (_Gen::max) ()>::value)>> =
6152+
true;
6153+
6154+
template <class _Diff, class _Urng>
6155+
using _Rng_from_urng_v1_or_v2 =
6156+
conditional_t<_Has_static_min_max<_Urng>, _Rng_from_urng_v2<_Diff, _Urng>, _Rng_from_urng<_Diff, _Urng>>;
6157+
60226158
#if _HAS_CXX17
60236159
template <class _PopIt, class _SampleIt, class _Diff, class _RngFn>
60246160
_SampleIt _Sample_reservoir_unchecked(
@@ -6076,7 +6212,7 @@ _SampleIt sample(_PopIt _First, _PopIt _Last, _SampleIt _Dest, _Diff _Count, _Ur
60766212
auto _UFirst = _STD _Get_unwrapped(_First);
60776213
auto _ULast = _STD _Get_unwrapped(_Last);
60786214
using _PopDiff = _Iter_diff_t<_PopIt>;
6079-
_Rng_from_urng<_PopDiff, remove_reference_t<_Urng>> _RngFunc(_Func);
6215+
_Rng_from_urng_v1_or_v2<_PopDiff, remove_reference_t<_Urng>> _RngFunc(_Func);
60806216
if constexpr (_Is_ranges_fwd_iter_v<_PopIt>) {
60816217
// source is forward: use selection sampling (stable)
60826218
using _CT = common_type_t<_Diff, _PopDiff>;
@@ -6119,7 +6255,7 @@ namespace ranges {
61196255
return _Output;
61206256
}
61216257

6122-
_Rng_from_urng<iter_difference_t<_It>, remove_reference_t<_Urng>> _RngFunc(_Func);
6258+
_Rng_from_urng_v1_or_v2<iter_difference_t<_It>, remove_reference_t<_Urng>> _RngFunc(_Func);
61236259
if constexpr (forward_iterator<_It>) {
61246260
auto _UFirst = _RANGES _Unwrap_iter<_Se>(_STD move(_First));
61256261
auto _Pop_size = _RANGES distance(_UFirst, _RANGES _Unwrap_sent<_It>(_STD move(_Last)));
@@ -6140,7 +6276,7 @@ namespace ranges {
61406276
return _Output;
61416277
}
61426278

6143-
_Rng_from_urng<range_difference_t<_Rng>, remove_reference_t<_Urng>> _RngFunc(_Func);
6279+
_Rng_from_urng_v1_or_v2<range_difference_t<_Rng>, remove_reference_t<_Urng>> _RngFunc(_Func);
61446280
if constexpr (forward_range<_Rng>) {
61456281
auto _UFirst = _Ubegin(_Range);
61466282
auto _Pop_size = _RANGES distance(_UFirst, _Uend(_Range));
@@ -6243,7 +6379,7 @@ void _Random_shuffle1(_RanIt _First, _RanIt _Last, _RngFn& _RngFunc) {
62436379
_EXPORT_STD template <class _RanIt, class _Urng>
62446380
void shuffle(_RanIt _First, _RanIt _Last, _Urng&& _Func) { // shuffle [_First, _Last) using URNG _Func
62456381
using _Urng0 = remove_reference_t<_Urng>;
6246-
_Rng_from_urng<_Iter_diff_t<_RanIt>, _Urng0> _RngFunc(_Func);
6382+
_Rng_from_urng_v1_or_v2<_Iter_diff_t<_RanIt>, _Urng0> _RngFunc(_Func);
62476383
_STD _Random_shuffle1(_First, _Last, _RngFunc);
62486384
}
62496385

@@ -6256,7 +6392,7 @@ namespace ranges {
62566392
_STATIC_CALL_OPERATOR _It operator()(_It _First, _Se _Last, _Urng&& _Func) _CONST_CALL_OPERATOR {
62576393
_STD _Adl_verify_range(_First, _Last);
62586394

6259-
_Rng_from_urng<iter_difference_t<_It>, remove_reference_t<_Urng>> _RngFunc(_Func);
6395+
_Rng_from_urng_v1_or_v2<iter_difference_t<_It>, remove_reference_t<_Urng>> _RngFunc(_Func);
62606396
auto _UResult = _Shuffle_unchecked(
62616397
_RANGES _Unwrap_iter<_Se>(_STD move(_First)), _RANGES _Unwrap_sent<_It>(_STD move(_Last)), _RngFunc);
62626398

@@ -6267,7 +6403,7 @@ namespace ranges {
62676403
template <random_access_range _Rng, class _Urng>
62686404
requires permutable<iterator_t<_Rng>> && uniform_random_bit_generator<remove_reference_t<_Urng>>
62696405
_STATIC_CALL_OPERATOR borrowed_iterator_t<_Rng> operator()(_Rng&& _Range, _Urng&& _Func) _CONST_CALL_OPERATOR {
6270-
_Rng_from_urng<range_difference_t<_Rng>, remove_reference_t<_Urng>> _RngFunc(_Func);
6406+
_Rng_from_urng_v1_or_v2<range_difference_t<_Rng>, remove_reference_t<_Urng>> _RngFunc(_Func);
62716407

62726408
return _RANGES _Rewrap_iterator(_Range, _Shuffle_unchecked(_Ubegin(_Range), _Uend(_Range), _RngFunc));
62736409
}
@@ -6313,11 +6449,11 @@ void random_shuffle(_RanIt _First, _RanIt _Last, _RngFn&& _RngFunc) {
63136449
struct _Rand_urng_from_func { // wrap rand() as a URNG
63146450
using result_type = unsigned int;
63156451

6316-
static result_type(min)() { // return minimum possible generated value
6452+
static constexpr result_type(min)() { // return minimum possible generated value
63176453
return 0;
63186454
}
63196455

6320-
static result_type(max)() { // return maximum possible generated value
6456+
static constexpr result_type(max)() { // return maximum possible generated value
63216457
return RAND_MAX;
63226458
}
63236459

0 commit comments

Comments
 (0)