Skip to content

Commit 62efc89

Browse files
committed
Refine code
1. Add copyright info 2. Overload structure for customized random seed
1 parent 16ed4a9 commit 62efc89

File tree

2 files changed

+41
-6
lines changed

2 files changed

+41
-6
lines changed

paddle/operators/math/sampler.cc

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
115
#include "sampler.h"
216

317
namespace paddle {
@@ -7,8 +21,13 @@ Sampler::~Sampler() {}
721

822
UniformSampler::UniformSampler(int64 range)
923
: Sampler(range), inv_range_(1.0 / range) {
10-
std::random_device r;
11-
random_engine_ = std::make_shared<std::mt19937>(r());
24+
random_engine_ = std::make_shared<std::mt19937>(seed_);
25+
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
26+
}
27+
28+
UniformSampler::UniformSampler(int64 range, unsigned int seed)
29+
: Sampler(range, seed), inv_range_(1.0 / range) {
30+
random_engine_ = std::make_shared<std::mt19937>(seed_);
1231
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
1332
}
1433

@@ -18,11 +37,15 @@ float UniformSampler::Probability(int64 value) const { return inv_range_; }
1837

1938
LogUniformSampler::LogUniformSampler(int64 range)
2039
: Sampler(range), log_range_(log(range + 1)) {
21-
std::random_device r;
22-
random_engine_ = std::make_shared<std::mt19937>(r());
40+
random_engine_ = std::make_shared<std::mt19937>(seed_);
2341
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
2442
}
2543

44+
LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed)
45+
: Sampler(range, seed), log_range_(log(range + 1)) {
46+
random_engine_ = std::make_shared<std::mt19937>(seed_);
47+
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
48+
}
2649
int64 LogUniformSampler::Sample() const {
2750
// Got Log Uniform distribution from uniform distribution by
2851
// inverse_transform_sampling method

paddle/operators/math/sampler.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,21 @@ namespace paddle {
2020
namespace operators {
2121
namespace math {
2222

23-
// TODO: Support for GPU
23+
// TODO(wanghaoshuang): Support for GPU
2424

2525
/**
2626
* Sample integers from [0, range).
2727
*/
2828
class Sampler {
2929
public:
30-
explicit Sampler(int64 range) : range_(range) { /* check range > 0*/
30+
explicit Sampler(int64 range) : range_(range) {
31+
PADDLE_ENFORCE_GT(range, 0);
32+
std::random_device r;
33+
seed_ = r();
34+
}
35+
explicit Sampler(int64 range, unsigned int seed)
36+
: range_(range), seed_(seed) {
37+
PADDLE_ENFORCE_GT(range, 0);
3138
}
3239
virtual ~Sampler();
3340
// Sample a single value
@@ -39,6 +46,7 @@ class Sampler {
3946

4047
protected:
4148
const int64 range_;
49+
unsigned int seed_;
4250
};
4351

4452
/**
@@ -50,6 +58,8 @@ class UniformSampler : public Sampler {
5058
public:
5159
explicit UniformSampler(int64 range);
5260

61+
explicit UniformSampler(int64 range, unsigned int seed);
62+
5363
~UniformSampler() override {}
5464

5565
int64 Sample() const override;
@@ -71,6 +81,8 @@ class LogUniformSampler : public Sampler {
7181
public:
7282
explicit LogUniformSampler(int64 range);
7383

84+
explicit LogUniformSampler(int64 range, unsigned int seed);
85+
7486
~LogUniformSampler() override {}
7587

7688
int64 Sample() const override;

0 commit comments

Comments
 (0)