Skip to content

Commit 16ed4a9

Browse files
committed
Add math function for sampling integers from:
1. uniform distribution 2. log uniform distribution
1 parent d89ff5b commit 16ed4a9

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed

paddle/operators/math/sampler.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "sampler.h"

#include <cmath>
2+
3+
namespace paddle {
4+
namespace random {
5+
6+
Sampler::~Sampler() {}
7+
8+
UniformSampler::UniformSampler(int64 range)
9+
: Sampler(range), inv_range_(1.0 / range) {
10+
std::random_device r;
11+
random_engine_ = std::make_shared<std::mt19937>(r());
12+
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
13+
}
14+
15+
int64 UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
16+
17+
float UniformSampler::Probability(int64 value) const { return inv_range_; }
18+
19+
LogUniformSampler::LogUniformSampler(int64 range)
20+
: Sampler(range), log_range_(log(range + 1)) {
21+
std::random_device r;
22+
random_engine_ = std::make_shared<std::mt19937>(r());
23+
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
24+
}
25+
26+
int64 LogUniformSampler::Sample() const {
27+
// Got Log Uniform distribution from uniform distribution by
28+
// inverse_transform_sampling method
29+
// More details:
30+
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
31+
const int64 value =
32+
static_cast<int64>(exp((*dist_)(*random_engine_) * log_range_)) - 1;
33+
// Mathematically, value should be <= range_, but might not be due to some
34+
// floating point roundoff, so we mod by range_.
35+
return value % range_;
36+
}
37+
38+
float LogUniformSampler::Probability(int64 value) const {
39+
// Given f(x) = 1/[(x+1) * log_range_]
40+
// The value's probability is integral of f(x) from value to (value + 1)
41+
// More details:
42+
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler
43+
return (log((value + 2.0) / (value + 1.0))) / log_range_;
44+
}
45+
46+
} // namespace random
47+
} // namespace paddle

paddle/operators/math/sampler.h

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <memory>
17+
#include <random>
18+
// Integer type used for sample values and ranges throughout the samplers.
using int64 = long;
19+
namespace paddle {
20+
namespace operators {
21+
namespace math {
22+
23+
// TODO: Support for GPU
24+
25+
/**
26+
* Sample integers from [0, range).
27+
*/
28+
class Sampler {
29+
public:
30+
explicit Sampler(int64 range) : range_(range) { /* check range > 0*/
31+
}
32+
virtual ~Sampler();
33+
// Sample a single value
34+
virtual int64 Sample() const = 0;
35+
// The probability that a single call to Sample() returns the given value.
36+
virtual float Probability(int64 value) const = 0;
37+
38+
int64 range() { return range_; };
39+
40+
protected:
41+
const int64 range_;
42+
};
43+
44+
/**
45+
* Sample integers from [0, range).
46+
* And the distribution function is:
47+
* P(x) = 1 / range
48+
*/
49+
class UniformSampler : public Sampler {
50+
public:
51+
explicit UniformSampler(int64 range);
52+
53+
~UniformSampler() override {}
54+
55+
int64 Sample() const override;
56+
57+
float Probability(int64 value) const override;
58+
59+
private:
60+
const float inv_range_;
61+
std::shared_ptr<std::mt19937_64> random_engine_;
62+
std::shared_ptr<std::uniform_int_distribution<>> dist_;
63+
};
64+
65+
/**
66+
* Sample integers from [0, range).
67+
* And the distribution function is:
68+
* P(x) = (1/ln(range+1)) * ln(1 + 1/(x + 1))
69+
*/
70+
class LogUniformSampler : public Sampler {
71+
public:
72+
explicit LogUniformSampler(int64 range);
73+
74+
~LogUniformSampler() override {}
75+
76+
int64 Sample() const override;
77+
78+
float Probability(int64 value) const override;
79+
80+
private:
81+
const float log_range_;
82+
std::shared_ptr<std::mt19937_64> random_engine_;
83+
std::shared_ptr<std::uniform_real_distribution<>> dist_;
84+
};
85+
86+
} // math
87+
} // namespace operators
88+
} // namespace paddle

0 commit comments

Comments
 (0)