1
- /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1
+ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2
2
3
3
Licensed under the Apache License, Version 2.0 (the "License");
4
4
you may not use this file except in compliance with the License.
@@ -13,58 +13,123 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/math/sampler.h"
16
+ #include < iostream>
17
+ #include < queue>
18
+ #include < utility>
19
+ #include < vector>
16
20
17
21
namespace paddle {
18
- namespace random {
22
+ namespace operators {
23
+ namespace math {
19
24
20
25
// Out-of-line definition of the Sampler destructor. It performs no work of
// its own beyond the implicit destruction of members.
Sampler::~Sampler() = default;
21
26
22
- UniformSampler::UniformSampler (int64 range)
23
- : Sampler(range), inv_range_(1.0 / range) {
24
- random_engine_ = std::make_shared<std::mt19937 >(seed_);
27
+ UniformSampler::UniformSampler (int64_t range, unsigned int seed )
28
+ : Sampler(range, seed ), inv_range_(1.0 / ( range + 1 ) ) {
29
+ random_engine_ = std::make_shared<std::mt19937_64 >(seed_);
25
30
dist_ = std::make_shared<std::uniform_int_distribution<>>(0 , range);
26
31
}
27
32
28
- UniformSampler::UniformSampler (int64 range, unsigned int seed)
29
- : Sampler(range, seed), inv_range_(1.0 / range) {
30
- random_engine_ = std::make_shared<std::mt19937>(seed_);
31
- dist_ = std::make_shared<std::uniform_int_distribution<>>(0 , range);
32
- }
33
-
34
- int64 UniformSampler::Sample () const { return (*dist_)(*random_engine_); }
33
+ int64_t UniformSampler::Sample () const { return (*dist_)(*random_engine_); }
35
34
36
- float UniformSampler::Probability (int64 value) const { return inv_range_; }
35
+ float UniformSampler::Probability (int64_t value) const { return inv_range_; }
37
36
38
- LogUniformSampler::LogUniformSampler (int64 range)
39
- : Sampler(range), log_range_(log(range + 1 )) {
40
- random_engine_ = std::make_shared<std::mt19937>(seed_);
41
- dist_ = std::make_shared<std::uniform_real_distribution<>>(0 , 1 );
42
- }
43
-
44
- LogUniformSampler::LogUniformSampler (int64 range, unsigned int seed)
37
+ LogUniformSampler::LogUniformSampler (int64_t range, unsigned int seed)
45
38
: Sampler(range, seed), log_range_(log(range + 1 )) {
46
- random_engine_ = std::make_shared<std::mt19937 >(seed_);
39
+ random_engine_ = std::make_shared<std::mt19937_64 >(seed_);
47
40
dist_ = std::make_shared<std::uniform_real_distribution<>>(0 , 1 );
48
41
}
49
- int64 LogUniformSampler::Sample () const {
42
+
43
+ int64_t LogUniformSampler::Sample () const {
50
44
// Got Log Uniform distribution from uniform distribution by
51
45
// inverse_transform_sampling method
52
46
// More details:
53
47
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
54
- const int64 value =
55
- static_cast <int64 >(exp ((*dist_)(*random_engine_) * log_range_)) - 1 ;
48
+ const int64_t value =
49
+ static_cast <int64_t >(exp ((*dist_)(*random_engine_) * log_range_)) - 1 ;
56
50
// Mathematically, value should be <= range_, but might not be due to some
57
51
// floating point roundoff, so we mod by range_.
58
52
return value % range_;
59
53
}
60
54
61
- float LogUniformSampler::Probability (int64 value) const {
55
+ float LogUniformSampler::Probability (int64_t value) const {
62
56
// Given f(x) = 1/[(x+1) * log_range_]
63
57
// The value's probability is integral of f(x) from value to (value + 1)
64
58
// More details:
65
59
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler
66
60
return (log ((value + 2.0 ) / (value + 1.0 ))) / log_range_;
67
61
}
68
62
69
- } // namespace random
63
+ CustomSampler::CustomSampler (int64_t range, const float * probabilities,
64
+ unsigned int seed)
65
+ : Sampler(range, seed) {
66
+ random_engine_ = std::make_shared<std::mt19937_64>(seed_);
67
+ real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0 , 1 );
68
+ int_dist_ = std::make_shared<std::uniform_int_distribution<>>(0 , range);
69
+ alias_probs_ = std::make_shared<std::vector<float >>(range + 1 );
70
+ alias_ = std::make_shared<std::vector<int64_t >>(range + 1 );
71
+ probs_ = std::make_shared<std::vector<float >>(range + 1 );
72
+
73
+ std::queue<std::pair<int64_t , float >> bigs;
74
+ std::queue<std::pair<int64_t , float >> littles;
75
+ for (int64_t i = 0 ; i <= range; ++i) {
76
+ (*probs_)[i] = probabilities[i];
77
+ float normal_prob = probabilities[i] * (range + 1 );
78
+ if (normal_prob - 1.0 > 1e-4 ) {
79
+ bigs.emplace (i, normal_prob);
80
+ } else if (1.0 - normal_prob > 1e-4 ) {
81
+ littles.emplace (i, normal_prob);
82
+ } else {
83
+ (*alias_probs_)[i] = normal_prob;
84
+ (*alias_)[i] = -1 ;
85
+ }
86
+ }
87
+
88
+ while ((!littles.empty ()) && (!bigs.empty ())) {
89
+ auto big = bigs.front ();
90
+ auto little = littles.front ();
91
+ bigs.pop ();
92
+ littles.pop ();
93
+ (*alias_probs_)[little.first ] = little.second ;
94
+ (*alias_)[little.first ] = big.first ;
95
+ auto big_left = big.second - (1 - little.second );
96
+ if (big_left - 1.0 > 1e-4 ) {
97
+ bigs.emplace (big.first , big_left);
98
+ } else if (1.0 - big_left > 1e-4 ) {
99
+ littles.emplace (big.first , big_left);
100
+ } else {
101
+ (*alias_probs_)[big.first ] = big_left;
102
+ (*alias_)[big.first ] = -1 ;
103
+ }
104
+ }
105
+
106
+ if (!littles.empty ()) { // littles.second is close to 1.0
107
+ auto little = littles.front ();
108
+ (*alias_probs_)[little.first ] = 1.0 ;
109
+ (*alias_)[little.first ] = -1 ;
110
+ }
111
+
112
+ if (!bigs.empty ()) { // bigs.second is close to 1.0
113
+ auto big = bigs.front ();
114
+ (*alias_probs_)[big.first ] = 1.0 ;
115
+ (*alias_)[big.first ] = -1 ;
116
+ }
117
+ }
118
+
119
+ int64_t CustomSampler::Sample () const {
120
+ auto index = (*int_dist_)(*random_engine_);
121
+ auto p = (*real_dist_)(*random_engine_);
122
+ if (p > (*alias_probs_)[index]) {
123
+ return (*alias_)[index];
124
+ } else {
125
+ return index;
126
+ }
127
+ }
128
+
129
+ float CustomSampler::Probability (int64_t value) const {
130
+ return (*probs_)[value];
131
+ }
132
+
133
+ } // namespace math
134
+ } // namespace operators
70
135
} // namespace paddle
0 commit comments