-
-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathrandom-sampler.cpp
More file actions
85 lines (70 loc) · 1.93 KB
/
random-sampler.cpp
File metadata and controls
85 lines (70 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
//
// random-sampler.cpp
//
// Copyright © 2020 by Blockchain Commons, LLC
// Licensed under the "BSD-2-Clause Plus Patent License"
//
#include "random-sampler.hpp"
#include <assert.h>
#include <algorithm>
#include <iterator>
#include <numeric>
#include <utility>
using namespace std;
namespace ur {

// Builds the alias table used by next(), following Vose's alias method as
// presented by Keith Schwarz ("Darts, Dice, and Coins"). After construction,
// probs_[i] is the probability of keeping column i when it is picked, and
// aliases_[i] is the index returned otherwise.
//
// probs: one weight per outcome. Weights must be non-negative and have a
//        positive sum; they are normalized here, so they need not sum to 1.
//        Both conditions are checked only via assert().
//
// NOTE(review): the exact floating-point steps (normalization expression,
// reversed index order, the `P[g] += P[a] - 1` update) are deliberately left
// untouched; presumably the sampler must agree bit-for-bit with other
// implementations of this algorithm -- confirm before altering them.
RandomSampler::RandomSampler(std::vector<double> probs) {
    for(auto p: probs) { assert(p >= 0); }   // also rejects NaN

    // Normalize given probabilities so that they sum to n (average 1).
    auto sum = accumulate(probs.begin(), probs.end(), 0.0);
    assert(sum > 0);
    auto n = probs.size();
    vector<double> P;
    P.reserve(n);
    transform(probs.begin(), probs.end(), back_inserter(P), [&](double d) { return d * double(n) / sum; });

    vector<int> S;   // "small" worklist: indices whose scaled probability is < 1
    S.reserve(n);
    vector<int> L;   // "large" worklist: indices whose scaled probability is >= 1
    L.reserve(n);

    // Set separate index lists for small and large probabilities.
    // At variance from Schwarz, we reverse the index order (n-1 down to 0).
    // Counting down with the unsigned size index avoids narrowing n - 1 to
    // int; the loop body runs for j = n-1, ..., 0 and not at all when n == 0.
    for(auto j = n; j-- > 0; ) {
        auto i = static_cast<int>(j);
        if(P[i] < 1) {
            S.push_back(i);
        } else {
            L.push_back(i);
        }
    }

    // Work through index lists, pairing one small entry with one large entry
    // per step.
    vector<double> table_probs(n, 0);
    vector<int> table_aliases(n, 0);
    while(!S.empty() && !L.empty()) {
        auto a = S.back(); S.pop_back(); // Schwarz's l
        auto g = L.back(); L.pop_back(); // Schwarz's g
        table_probs[a] = P[a];
        table_aliases[a] = g;
        // g gives up the (1 - P[a]) of mass used to fill a's column.
        P[g] += P[a] - 1;
        if(P[g] < 1) {
            S.push_back(g);
        } else {
            L.push_back(g);
        }
    }
    // Remaining large entries fill their own column completely.
    while(!L.empty()) {
        table_probs[L.back()] = 1;
        L.pop_back();
    }
    while(!S.empty()) {
        // can only happen through numeric instability
        table_probs[S.back()] = 1;
        S.pop_back();
    }

    // Move (rather than copy) the finished tables into the members.
    probs_ = std::move(table_probs);
    aliases_ = std::move(table_aliases);
}

// Draws one index in [0, n), where n is the number of weights given to the
// constructor, using the alias table built there.
//
// rng: called exactly twice per draw, in order -- first to pick a column,
//      then to choose between the column's own index and its alias.
//      Expected to return values in [0, 1) -- NOTE(review): confirm for all
//      callers.
int RandomSampler::next(std::function<double()> rng) {
    auto r1 = rng();
    auto r2 = rng();
    auto n = probs_.size();
    assert(n > 0);

    // Pick a column. If rng() returns exactly 1.0, the truncated product
    // equals n -- one past the end of the tables -- so clamp the index into
    // [0, n-1]. Indices produced from values in [0, 1) are unchanged.
    auto i = int(double(n) * r1);
    if(i < 0) {
        i = 0;
    } else if(i >= int(n)) {
        i = int(n) - 1;
    }

    return r2 < probs_[i] ? i : aliases_[i];
}

} // namespace ur