Skip to content

Commit 1446484

Browse files
committed
Add beta distribution wrapper. Fix api names
1 parent 73a0892 commit 1446484

File tree

1 file changed

+171
-0
lines changed
  • src/rlenvs/utils/maths/statistics/distributions

1 file changed

+171
-0
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
#ifndef BETA_DIST_H
2+
#define BETA_DIST_H
3+
4+
#include "rlenvs/rlenvs_types_v2.h"
5+
#include "rlenvs/rlenvs_consts.h"
6+
#include "rlevns/utils/maths/math_utils.h"
7+
8+
#include <boost/random/beta_distribution.hpp>
9+
10+
#include <vector>
11+
#include <cmath>
12+
#include <type_traits>
13+
14+
15+
16+
namespace rlenvscpp {
17+
namespace utils{
18+
namespace maths {
19+
namespace stats {
20+
21+
///
22+
/// \brief The beta distribution
23+
/// is a real-valued distribution which produces values in the range [0, 1].
24+
/// It has two parameters, alpha and beta.
25+
///
26+
template<typename RealType = real_t>
27+
class BetaDist
28+
{
29+
public:
30+
31+
static_assert(is_floating_point<RealType>::value, "Not a floating point type");
32+
33+
///
34+
/// \breif The return type every time we call pdf, sample
35+
///
36+
typedef RealType result_type;
37+
38+
///
39+
/// \brief Constructor
40+
///
41+
explicit BetaDist(result_type alpha=1.0, result_type std = 1.0);
42+
43+
///
44+
/// \brief compute the value of the PDF at the given point
45+
///
46+
result_type pdf(result_type x)const;
47+
48+
///
49+
/// \brief Sample from the distribution
50+
///
51+
result_type sample() const;
52+
53+
///
54+
/// \brief Sample from the distribution
55+
///
56+
result_type sample(uint_t seed) const;
57+
58+
///
59+
/// \brief sample from the distribution
60+
///
61+
std::vector<result_type> sample_many(uint_t size) const;
62+
63+
///
64+
/// \brief sample from the distribution
65+
///
66+
std::vector<result_type> sample_many(uint_t size, uint_t seed) const;
67+
68+
///
69+
/// \brief The mean value of the distribution
70+
/// see https://en.wikipedia.org/wiki/Beta_distribution
71+
///
72+
result_type mean()const{return dist_.alpha / (dist_.alpha() + dist_.beta());}
73+
74+
///
75+
/// \brief The variance of the distribution.
76+
/// see https://en.wikipedia.org/wiki/Beta_distribution
77+
///
78+
result_type variance()const{return dist_.stddev();}
79+
80+
private:
81+
82+
///
83+
/// \brief The underlying distribution. Mutable
84+
/// as the API exposes const methods and the compiler
85+
/// complains
86+
///
87+
mutable boost::beta_distribution<RealType> dist_;
88+
89+
};
90+
91+
92+
template<typename RealType>
93+
BetaDist<RealType>::BetaDist(RealType alpha, RealType beta)
94+
:
95+
dist_(alpha, beta)
96+
{}
97+
98+
template<typename RealType>
99+
RealType
100+
BetaDist<RealType>::variance()const{
101+
auto a = dist_.alpha()
102+
auto b = dist_.beta();
103+
return a*b / (utils::maths::sqr(a + b)*(a + b + 1) );
104+
}
105+
106+
template<typename RealType>
107+
RealType
108+
BetaDist<RealType>::sample() const{
109+
110+
std::random_device rd{};
111+
std::mt19937 gen{rd()};
112+
return dist_(gen);
113+
114+
}
115+
116+
template<typename RealType>
117+
RealType
118+
BetaDist<RealType>::sample(uint_t seed) const{
119+
120+
std::mt19937 gen{seed};
121+
return dist_(gen);
122+
}
123+
124+
template<typename RealType>
125+
std::vector<RealType>
126+
BetaDist<RealType>::sample_many(uint_t size) const{
127+
128+
std::vector<RealType> samples(size);
129+
std::random_device rd{};
130+
std::mt19937 gen{rd()};
131+
132+
for(uint_t i=0; i<size; ++i){
133+
samples[i] = dist_(gen);
134+
}
135+
136+
return samples;
137+
}
138+
139+
template<typename RealType>
140+
std::vector<RealType>
141+
BetaDist<RealType>::sample_many(uint_t size, uint_t seed) const{
142+
143+
std::vector<RealType> samples(size);
144+
std::mt19937 gen(seed);
145+
146+
for(uint_t i=0; i<size; ++i){
147+
samples[i] = dist_(gen);
148+
}
149+
150+
return samples;
151+
}
152+
153+
template<typename RealType>
154+
RealType
155+
BetaDist<RealType>::pdf(RealType x)const{
156+
157+
auto alpha = dist_.alpha();
158+
auto beta = dist_.beta();
159+
auto beta_func = std::beta(alpha, beta);
160+
return std::pow(x, alpha - 1)*std::pow(1.0-x, beta - 1.0) / beta_func;
161+
162+
}
163+
164+
}
165+
}
166+
}
167+
}
168+
169+
170+
171+
#endif

0 commit comments

Comments
 (0)