Skip to content

Commit b1c9fa0

Browse files
committed
Refactor tensorboard class. Add n_copies to environments. Update README doc
1 parent fec3d04 commit b1c9fa0

File tree

16 files changed

+214
-116
lines changed

16 files changed

+214
-116
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ FILE(GLOB SRCS src/bitrl/*.cpp
158158
src/bitrl/dynamics/*.cpp
159159
src/bitrl/utils/*.cpp
160160
src/bitrl/utils/io/*.cpp
161-
src/bitrl/utils/io/tensor_board_server/*.cpp
162161
src/bitrl/utils/maths/statistics/distributions/*.cpp
163162
#src/bitrl/utils/geometry/*.cpp
164163
#src/bitrl/utils/geometry/shapes/*.cpp

README.md

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,94 +13,121 @@ The following is an example how to use the
1313
#include "bitrl/envs/gymnasium/toy_text/frozen_lake_env.h"
1414
#include "bitrl/network/rest_rl_env_client.h"
1515
16+
#include <any>
1617
#include <iostream>
1718
#include <string>
1819
#include <unordered_map>
19-
#include <any>
2020
21-
namespace example_1{
21+
namespace example_1
22+
{
23+
using namespace bitrl;
2224
2325
const std::string SERVER_URL = "http://0.0.0.0:8001/api";
2426
2527
using bitrl::envs::gymnasium::FrozenLake;
26-
using bitrl::envs::RESTApiServerWrapper;
28+
using bitrl::network::RESTRLEnvClient;
29+
30+
void test_frozen_lake(RESTRLEnvClient &server)
31+
{
2732
28-
void test_frozen_lake(const RESTApiServerWrapper& server){
33+
// the environment is not registered with the server
34+
std::cout << "Is environment registered: " << server.is_env_registered(FrozenLake<4>::name)
35+
<< std::endl;
2936
37+
// when the environment is created we register it with the REST client
3038
FrozenLake<4> env(server);
3139
32-
std::cout<<"Environame URL: "<<env.get_url()<<std::endl;
40+
// environment name can also be accessed via env.env_name()
41+
std::cout << "Is environment registered: " << server.is_env_registered(env.env_name())
42+
<< std::endl;
43+
std::cout << "Environment URL: " << env.get_url() << std::endl;
3344
34-
// make the environment
45+
// make the environment; we pass both make options
46+
// and reset options
3547
std::unordered_map<std::string, std::any> make_ops;
3648
make_ops.insert({"is_slippery", false});
3749
38-
std::unordered_map<std::string, std::any> reset_ops;
39-
reset_ops.insert({"seed", static_cast<uint_t>(42)});
50+
std::unordered_map<std::string, std::any> reset_ops;
51+
reset_ops.insert({"seed", static_cast<uint_t>(42)});
4052
env.make("v1", make_ops, reset_ops);
4153
42-
std::cout<<"Is environment created? "<<env.is_created()<<std::endl;
43-
std::cout<<"Is environment alive? "<<env.is_alive()<<std::endl;
44-
std::cout<<"Number of valid actions? "<<env.n_actions()<<std::endl;
45-
std::cout<<"Number of states? "<<env.n_states()<<std::endl;
46-
std::cout<<"Env idx: "<<env.idx()<<std::endl;
54+
// query the environment version
55+
std::cout << "Environment version: " << env.version() << std::endl;
56+
57+
// once the env is created we can get its id
58+
std::cout << "Environment idx is: " << env.idx() << std::endl;
59+
60+
// the create flag should be true
61+
std::cout << "Is environment created? " << env.is_created() << std::endl;
62+
63+
// environment should be alive on the server
64+
std::cout << "Is environment alive? " << env.is_alive() << std::endl;
65+
66+
// FrozenLake is a discrete state-action env so we can
67+
// query number of actions and states
68+
std::cout << "Number of valid actions? " << env.n_actions() << std::endl;
69+
std::cout << "Number of states? " << env.n_states() << std::endl;
70+
71+
// how many copies of this environment
72+
auto n_copies = env.n_copies();
73+
std::cout << "n_copies: " << n_copies << std::endl;
4774
4875
// reset the environment
4976
auto time_step = env.reset();
5077
51-
std::cout<<"Reward on reset: "<<time_step.reward()<<std::endl;
52-
std::cout<<"Observation on reset: "<<time_step.observation()<<std::endl;
53-
std::cout<<"Is terminal state: "<<time_step.done()<<std::endl;
54-
std::cout<<"Env idx: "<<env.idx()<<std::endl;
78+
std::cout << "Reward on reset: " << time_step.reward() << std::endl;
79+
std::cout << "Observation on reset: " << time_step.observation() << std::endl;
80+
std::cout << "Is terminal state: " << time_step.done() << std::endl;
5581
5682
//...print the time_step
57-
std::cout<<time_step<<std::endl;
83+
std::cout << time_step << std::endl;
5884
5985
// take an action in the environment
60-
// 2 = RIGHT
86+
// 2 = RIGHT
6187
auto new_time_step = env.step(2);
62-
63-
std::cout<<new_time_step<<std::endl;
88+
std::cout << new_time_step << std::endl;
6489
6590
// get the dynamics of the environment for the given state and action
6691
auto state = 0;
6792
auto action = 1;
6893
auto dynamics = env.p(state, action);
6994
70-
std::cout<<"Dynamics for state="<<state<<" and action="<<action<<std::endl;
71-
72-
for(auto item:dynamics){
73-
std::cout<<std::get<0>(item)<<std::endl;
74-
std::cout<<std::get<1>(item)<<std::endl;
75-
std::cout<<std::get<2>(item)<<std::endl;
76-
std::cout<<std::get<3>(item)<<std::endl;
95+
std::cout << "Dynamics for state=" << state << " and action=" << action << std::endl;
96+
for (auto item : dynamics)
97+
{
98+
std::cout << std::get<0>(item) << std::endl;
99+
std::cout << std::get<1>(item) << std::endl;
100+
std::cout << std::get<2>(item) << std::endl;
101+
std::cout << std::get<3>(item) << std::endl;
77102
}
78-
79-
action = env.sample_action();
80-
std::cout<<"Action sampled: "<<action<<std::endl;
81-
82-
new_time_step = env.step(action);
83-
std::cout<<new_time_step<<std::endl;
103+
104+
// discrete action environments can sample
105+
// actions
106+
action = env.sample_action();
107+
std::cout << "Action sampled: " << action << std::endl;
108+
109+
new_time_step = env.step(action);
110+
std::cout << new_time_step << std::endl;
84111
85112
// close the environment
86113
env.close();
87114
}
88-
}
89115
116+
} // namespace example_1
90117
91-
int main(){
92-
93-
using namespace example_1;
94-
95-
RESTApiServerWrapper server(SERVER_URL, true);
96-
97-
std::cout<<"Testing FrozenLake..."<<std::endl;
118+
int main()
119+
{
120+
using namespace example_1;
121+
RESTRLEnvClient server(SERVER_URL, false);
122+
123+
std::cout << "Testing FrozenLake..." << std::endl;
98124
example_1::test_frozen_lake(server);
99-
std::cout<<"===================="<<std::endl;
100-
125+
std::cout << "====================" << std::endl;
126+
101127
return 0;
102128
}
103129
130+
104131
```
105132

106133
Gymnasium environments exposed over a REST-like API can be found at: <a href="https://github.com/pockerman/bitrl-rest-api">bitrl-rest-api</a>

examples/example_1/example_1.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ void test_frozen_lake(RESTRLEnvClient &server)
6363
std::cout << "Number of valid actions? " << env.n_actions() << std::endl;
6464
std::cout << "Number of states? " << env.n_states() << std::endl;
6565

66+
// how many copies of this environment
67+
auto n_copies = env.n_copies();
68+
std::cout << "n_copies: " << n_copies << std::endl;
69+
6670
// reset the environment
6771
auto time_step = env.reset();
6872

@@ -230,6 +234,11 @@ void test_cliff_world(RESTRLEnvClient &server)
230234
std::cout << "Number of valid actions? " << env.n_actions() << std::endl;
231235
std::cout << "Number of states? " << env.n_states() << std::endl;
232236

237+
// how many copies of this environment
238+
auto n_copies = env.n_copies();
239+
std::cout << "n_copies: " << n_copies << std::endl;
240+
241+
233242
// reset the environment
234243
auto time_step = env.reset();
235244

@@ -243,7 +252,6 @@ void test_cliff_world(RESTRLEnvClient &server)
243252
// take an action in the environment
244253
// 0 = UP
245254
auto new_time_step = env.step(0);
246-
247255
std::cout << new_time_step << std::endl;
248256

249257
// get the dynamics of the environment for the given state and action
@@ -262,6 +270,8 @@ void test_cliff_world(RESTRLEnvClient &server)
262270
std::cout << std::get<3>(item) << std::endl;
263271
}
264272

273+
274+
265275
// close the environment
266276
env.close();
267277
}

examples/example_12/example_12.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
#include "bitrl/bitrl_types.h"
2121
#include "bitrl/utils/io/json_file_reader.h"
22-
#include "bitrl/utils/io/tensor_board_server/tensorboard_server.h"
22+
#include "bitrl/network/tensorboard_server.h"
2323

2424
#include <filesystem>
2525
#include <iostream>
@@ -30,7 +30,7 @@ namespace example_12
3030
{
3131
using namespace bitrl;
3232
using utils::io::JSONFileReader;
33-
using utils::io::TensorboardServer;
33+
using network::TensorboardServer;
3434

3535
namespace fs = std::filesystem;
3636
const std::string CONFIG = "config.json";
@@ -90,7 +90,6 @@ int main()
9090
}
9191
catch (...)
9292
{
93-
9493
std::cout << "Unknown exception occured" << std::endl;
9594
}
9695

src/bitrl/envs/connect2/connect2_env.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,17 @@ namespace envs::connect2
1414
{
1515

1616
const std::string Connect2::name = "Connect2";
17+
std::atomic<uint_t> Connect2::n_copies_ = 0;
1718

1819
Connect2::Connect2()
1920
: EnvBase<TimeStep<std::vector<uint_t>>,
2021
DiscreteVectorStateDiscreteActionEnv<53, 0, 4, uint_t>>("Connect2"),
2122
discount_(1.0), board_()
2223
{
24+
++n_copies_;
2325
}
2426

25-
Connect2::Connect2(const Connect2 &other)
26-
: EnvBase<TimeStep<std::vector<uint_t>>,
27-
DiscreteVectorStateDiscreteActionEnv<53, 0, 4, uint_t>>(other),
28-
discount_(1.0), board_(other.board_), is_finished_(other.is_finished_)
29-
{
30-
}
27+
3128

3229
void Connect2::make(const std::string & /*version*/,
3330
const std::unordered_map<std::string, std::any> &options,

src/bitrl/envs/connect2/connect2_env.h

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <string>
2020
#include <unordered_map>
2121
#include <vector>
22+
#include <atomic>
2223

2324
namespace bitrl
2425
{
@@ -84,11 +85,6 @@ class Connect2 final : public EnvBase<TimeStep<std::vector<uint_t>>,
8485
///
8586
Connect2();
8687

87-
///
88-
///
89-
///
90-
Connect2(const Connect2 &other);
91-
9288
///
9389
/// \brief make. Builds the environment. Optionally we can choose if the
9490
/// environment will be slippery
@@ -115,12 +111,6 @@ class Connect2 final : public EnvBase<TimeStep<std::vector<uint_t>>,
115111
///
116112
virtual time_step_type reset() override final;
117113

118-
///
119-
/// \brief Create a new copy of the environment with the given
120-
/// copy index
121-
///
122-
Connect2 make_copy(uint_t cidx) const;
123-
124114
///
125115
/// \brief n_states. Returns the number of states
126116
///
@@ -151,6 +141,14 @@ class Connect2 final : public EnvBase<TimeStep<std::vector<uint_t>>,
151141
///
152142
std::vector<uint_t> get_valid_moves() const;
153143

144+
/**
145+
* Get the number of copies of this class
146+
* @return
147+
*/
148+
static uint_t n_copies() {
149+
return n_copies_.load();
150+
}
151+
154152
private:
155153
///
156154
/// \brief The discount factor
@@ -172,6 +170,12 @@ class Connect2 final : public EnvBase<TimeStep<std::vector<uint_t>>,
172170
///
173171
const uint_t win_val_{2};
174172

173+
/**
174+
* Counter to count the number of instances of this
175+
* class.
176+
*/
177+
static std::atomic<uint_t> n_copies_;
178+
175179
///
176180
/// \brief The representation of the board
177181
///

src/bitrl/envs/gdrl/gym_walk.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,13 @@ class GymWalk final : public EnvBase<TimeStep<uint_t>, ScalarDiscreteEnv<state_s
7474
* @brief Get the full URL for this environment endpoint on the server.
7575
* @return Environment URL string.
7676
*/
77-
std::string get_url() const{return api_server_->get_env_url(this->env_name());};
77+
std::string get_url() const{return api_server_->get_env_url(this->env_name());}
78+
79+
/**
80+
* Get the number of copies on the server for this environment
81+
* @return
82+
*/
83+
uint_t n_copies() const;
7884

7985
private:
8086
dynamics_t build_dynamics_from_response_(const nlohmann::json &response) const;
@@ -161,6 +167,14 @@ template <uint_t state_size> bool GymWalk<state_size>::is_alive() const
161167
return response["result"];
162168
}
163169

170+
template <uint_t state_size>
171+
uint_t GymWalk<state_size>::n_copies() const
172+
{
173+
auto response = this->api_server_->n_copies(this->env_name());
174+
return response["copies"];
175+
}
176+
177+
164178
template <uint_t state_size> void GymWalk<state_size>::close()
165179
{
166180

src/bitrl/envs/grid_world/grid_world_env.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
namespace bitrl::envs::grid_world
1010
{
1111

12+
13+
1214
GridWorldInitType from_string(const std::string &gw_init_type)
1315
{
1416

0 commit comments

Comments
 (0)