1
+ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License. */
14
+
1
15
#include " sampler.h"
2
16
3
17
namespace paddle {
@@ -7,8 +21,13 @@ Sampler::~Sampler() {}
7
21
8
22
UniformSampler::UniformSampler (int64 range)
9
23
: Sampler(range), inv_range_(1.0 / range) {
10
- std::random_device r;
11
- random_engine_ = std::make_shared<std::mt19937>(r ());
24
+ random_engine_ = std::make_shared<std::mt19937>(seed_);
25
+ dist_ = std::make_shared<std::uniform_int_distribution<>>(0 , range);
26
+ }
27
+
28
+ UniformSampler::UniformSampler (int64 range, unsigned int seed)
29
+ : Sampler(range, seed), inv_range_(1.0 / range) {
30
+ random_engine_ = std::make_shared<std::mt19937>(seed_);
12
31
dist_ = std::make_shared<std::uniform_int_distribution<>>(0 , range);
13
32
}
14
33
@@ -18,11 +37,15 @@ float UniformSampler::Probability(int64 value) const { return inv_range_; }
18
37
19
38
LogUniformSampler::LogUniformSampler (int64 range)
20
39
: Sampler(range), log_range_(log(range + 1 )) {
21
- std::random_device r;
22
- random_engine_ = std::make_shared<std::mt19937>(r ());
40
+ random_engine_ = std::make_shared<std::mt19937>(seed_);
23
41
dist_ = std::make_shared<std::uniform_real_distribution<>>(0 , 1 );
24
42
}
25
43
44
+ LogUniformSampler::LogUniformSampler (int64 range, unsigned int seed)
45
+ : Sampler(range, seed), log_range_(log(range + 1 )) {
46
+ random_engine_ = std::make_shared<std::mt19937>(seed_);
47
+ dist_ = std::make_shared<std::uniform_real_distribution<>>(0 , 1 );
48
+ }
26
49
int64 LogUniformSampler::Sample () const {
27
50
// Got Log Uniform distribution from uniform distribution by
28
51
// inverse_transform_sampling method
0 commit comments