Skip to content

Commit badf71e

Browse files
committed
Add multi armed bandits environment. Add bernoulli distribution wrapper
1 parent d433044 commit badf71e

File tree

4 files changed

+404
-0
lines changed

4 files changed

+404
-0
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#include "rlenvs/envs/multi_armed_bandits/multi_armed_bandits.h"
2+
#include "rlenvs/envs/time_step_type.h"
3+
4+
#include <vector>
5+
#include <exception>
6+
7+
namespace rlenvscpp{
8+
namespace envs{
9+
namespace bandits{
10+
11+
const std::string MultiArmedBandits::name = "MultiArmedBandits";
12+
13+
MultiArmedBandits::MultiArmedBandits()
14+
:
15+
EnvBase<TimeStep<Null>, MultiArmedBanditsSpace>(0, "MultiArmedBandits"),
16+
bandits_()
17+
{}
18+
19+
void
20+
MultiArmedBandits::make(const std::string& version,
21+
const std::unordered_map<std::string, std::any>& options){
22+
23+
24+
auto p_itr = options.find("p");
25+
if(p_itr == options.end()){
26+
throw std::logic_error("option p is missing");
27+
}
28+
29+
auto p = std::any_cast<std::vector<real_t>>(p_itr -> second);
30+
31+
bandits_.reserve(p.size());
32+
for(auto p_: p){
33+
bandits_.push_back(utils::maths::stats::BernoulliDist(p_));
34+
}
35+
36+
auto success_reward_itr = options.find("success_reward");
37+
if(success_reward_itr != options.end()){
38+
success_reward_ = std::any_cast<real_t>(success_reward_itr -> second);
39+
}
40+
else{
41+
success_reward_ = 1.0;
42+
}
43+
44+
auto fail_reward_itr = options.find("fail_reward");
45+
if(success_reward_itr != options.end()){
46+
fail_reward_ = std::any_cast<real_t>(fail_reward_itr -> second);
47+
}
48+
else{
49+
fail_reward_ = 0.0;
50+
}
51+
52+
53+
this -> set_version_(version);
54+
this -> set_make_options_(options);
55+
this -> make_created_();
56+
57+
}
58+
59+
MultiArmedBandits::time_step_type
60+
MultiArmedBandits::reset(uint_t seed,
61+
const std::unordered_map<std::string, std::any>& /*options*/){
62+
seed_ = seed;
63+
64+
return MultiArmedBandits::time_step_type(TimeStepTp::FIRST,
65+
0.0,
66+
Null(),
67+
1.0
68+
);
69+
70+
}
71+
72+
MultiArmedBandits::time_step_type
73+
MultiArmedBandits::step(const action_type& action){
74+
75+
if(action >= bandits_.size()){
76+
throw std::logic_error("Invalid action index");
77+
}
78+
79+
auto result = bandits_[action].sample(seed_);
80+
81+
if(result){
82+
this -> get_current_time_step_() = MultiArmedBandits::time_step_type(TimeStepTp::LAST,
83+
success_reward_,
84+
Null(),
85+
1.0
86+
);
87+
}
88+
else{
89+
this -> get_current_time_step_() = MultiArmedBandits::time_step_type(TimeStepTp::LAST,
90+
fail_reward_,
91+
Null(),
92+
1.0);
93+
}
94+
95+
return this -> get_current_time_step_();
96+
}
97+
98+
void
99+
MultiArmedBandits::close(){
100+
101+
bandits_.clear();
102+
this -> EnvBase<TimeStep<Null>, MultiArmedBanditsSpace>::close();
103+
}
104+
105+
106+
}
107+
}
108+
}
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#ifndef MULTI_ARMED_BANDITS_H
2+
#define MULTI_ARMED_BANDITS_H
3+
4+
#include "rlenvs/rlenvs_types_v2.h"
5+
#include "rlenvs/envs/env_base.h"
6+
#include "rlenvs/envs/time_step.h"
7+
#include "rlenvs/utils/maths/statistics/distributions/bernoulli_dist.h"
8+
9+
#include <vector>
10+
#include <string>
11+
#include <any>
12+
#include <unordered_map>
13+
14+
15+
namespace rlenvscpp{
16+
namespace envs{
17+
namespace bandits{
18+
19+
///
20+
/// \brief struct MultiArmedBanditsSpace specifies the
21+
/// MultiArmedBandits state-action space
22+
///
23+
struct MultiArmedBanditsSpace
24+
{
25+
26+
///
27+
/// \brief The type describing the state space for the environment
28+
///
29+
typedef Null state_space;
30+
31+
///
32+
/// \brief The type of the state
33+
///
34+
typedef Null state_type;
35+
36+
///
37+
/// \brief The type of the action space for the environment
38+
///
39+
typedef Null action_space;
40+
41+
///
42+
/// \brief The type of the action to be undertaken in the environment
43+
///
44+
typedef uint_t action_type;
45+
46+
47+
};
48+
49+
///
50+
/// \brief class MultiArmedBandits. Environment for simulating armed-bandits
51+
/// The bandits are represented as Bernoulli distribution. At each step
52+
/// only one bandit can be executed
53+
///
54+
class MultiArmedBandits final: public EnvBase<TimeStep<Null>, MultiArmedBanditsSpace>{
55+
56+
public:
57+
58+
///
59+
/// \brief name
60+
///
61+
static const std::string name;
62+
63+
///
64+
/// \brief The base type
65+
///
66+
typedef EnvBase<TimeStep<Null>, MultiArmedBanditsSpace> base_type;
67+
68+
///
69+
/// \brief The time step type we return every time a step in the
70+
/// environment is performed
71+
///
72+
typedef typename base_type::time_step_type time_step_type;
73+
74+
///
75+
/// \brief The type describing the state space for the environment
76+
///
77+
typedef typename base_type::state_space_type state_space_type;
78+
79+
///
80+
/// \brief The type of the action space for the environment
81+
///
82+
typedef typename base_type::action_space_type action_space_type;
83+
84+
///
85+
/// \brief The type of the action to be undertaken in the environment
86+
///
87+
typedef typename base_type::action_type action_type;
88+
89+
///
90+
/// \brief The type of the action to be undertaken in the environment
91+
///
92+
typedef typename base_type::state_type state_type;
93+
94+
95+
///
96+
/// \brief MultiArmedBandits Constructor
97+
///
98+
MultiArmedBandits();
99+
100+
///
101+
/// \brief make. Builds the environment.
102+
/// \param version. the version of the environment to build
103+
/// \param options. Options to use for building the environment.
104+
/// Concrete classes may choose to hold a copy
105+
///
106+
virtual void make(const std::string& version,
107+
const std::unordered_map<std::string, std::any>& options)override final;
108+
109+
///
110+
/// \brief close the environment
111+
///
112+
virtual void close()override final;
113+
114+
///
115+
/// \brief Reset the environment
116+
/// \param seed. The seed to use for resetting the environment
117+
/// \param options. Options to use for resetting the environment.
118+
///
119+
virtual time_step_type reset(uint_t seed,
120+
const std::unordered_map<std::string, std::any>& options)override final;
121+
122+
///
123+
/// \brief step in the environment by performing the given action
124+
/// \param action. The action to execute in the environment
125+
/// \return An instance of time_step_type
126+
virtual time_step_type step(const action_type& action)override final;
127+
128+
private:
129+
130+
///
131+
/// \brief The seed to use
132+
///
133+
uint_t seed_;
134+
135+
///
136+
/// \brief The success reward
137+
///
138+
real_t success_reward_;
139+
140+
///
141+
/// \brief The reward to return on fail.
142+
/// Default is zero
143+
///
144+
real_t fail_reward_;
145+
146+
///
147+
/// \brief Every bandit is represented as a Bernoulli distribution
148+
///
149+
std::vector<utils::maths::stats::BernoulliDist> bandits_;
150+
151+
152+
};
153+
154+
}
155+
}
156+
}
157+
158+
159+
160+
#endif
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include "rlenvs/utils/maths/statistics/distributions/bernoulli_dist.h"
2+
3+
4+
namespace rlenvscpp {
5+
namespace utils{
6+
namespace maths {
7+
namespace stats {
8+
9+
10+
BernoulliDist::BernoulliDist(result_type p)
11+
:
12+
p_(p),
13+
dist_(p)
14+
{}
15+
16+
BernoulliDist::result_type
17+
BernoulliDist::sample() const{
18+
19+
std::random_device rd{};
20+
std::mt19937 gen{rd()};
21+
return dist_(gen);
22+
}
23+
24+
25+
BernoulliDist::result_type
26+
BernoulliDist::sample(uint_t seed) const{
27+
28+
std::mt19937 gen{seed};
29+
return dist_(gen);
30+
}
31+
32+
33+
std::vector<BernoulliDist::result_type>
34+
BernoulliDist::sample_many(uint_t size) const{
35+
36+
std::vector<BernoulliDist::result_type> samples(size);
37+
std::random_device rd{};
38+
std::mt19937 gen{rd()};
39+
40+
for(uint_t i=0; i<size; ++i){
41+
samples[i] = dist_(gen);
42+
}
43+
44+
return samples;
45+
46+
}
47+
48+
49+
std::vector<BernoulliDist::result_type>
50+
BernoulliDist::sample_many(uint_t size, uint_t seed) const{
51+
52+
std::vector<BernoulliDist::result_type> samples(size);
53+
std::mt19937 gen(seed);
54+
55+
for(uint_t i=0; i<size; ++i){
56+
samples[i] = dist_(gen);
57+
}
58+
59+
return samples;
60+
}
61+
62+
}
63+
}
64+
}
65+
}

0 commit comments

Comments
 (0)