1515#include " ecole/scip/type.hpp"
1616#include " ecole/traits.hpp"
1717
18- #include < iostream >
18+ #include < optional >
1919
2020template <typename T> struct is_optional : std::false_type {};
2121template <typename T> struct is_optional <std::optional<T>> : std::true_type {};
@@ -67,8 +67,8 @@ class Environment {
6767 std::map<std::string, scip::Param> scip_params = {},
6868 Args&&... args) :
6969 the_dynamics (std::forward<Args>(args)...),
70- the_observation_function (data::parse(std::move(observation_function))),
7170 the_reward_function (data::parse(std::move(reward_function))),
71+ the_observation_function (data::parse(std::move(observation_function))),
7272 the_information_function (data::parse(std::move(information_function))),
7373 the_scip_params (std::move(scip_params)),
7474 the_random_engine (spawn_random_engine()) {}
@@ -112,22 +112,23 @@ class Environment {
112112 dynamics ().set_dynamics_random_state (model (), random_engine ());
113113
114114 // Reset data extraction function and bring model to initial state.
115- observation_function ().before_reset (model ());
116115 reward_function ().before_reset (model ());
116+ observation_function ().before_reset (model ());
117117 information_function ().before_reset (model ());
118- auto const [done, action_set] = dynamics ().reset_dynamics (model (), std::forward<Args>(args)...);
118+
119+ // Place the environment in its initial state
120+ auto [done, action_set] = dynamics ().reset_dynamics (model (), std::forward<Args>(args)...);
119121 can_transition = !done;
120122
121- auto observation = OptionalObservation{};
122- if (!done) {
123- observation = observation_function ().extract (model (), done);
124- }
123+ // Extract additional information to be returned by reset
124+ auto [reward, observation, information] = extract_reward_observation_information (done);
125+
125126 return {
126127 std::move (observation),
127128 std::move (action_set),
128- reward_function (). extract ( model (), done ),
129+ std::move (reward ),
129130 done,
130- information_function (). extract ( model (), done ),
131+ std::move (information ),
131132 };
132133 } catch (std::exception const &) {
133134 can_transition = false ;
@@ -170,19 +171,19 @@ class Environment {
170171 throw Exception (" Environment need to be reset." );
171172 }
172173 try {
173- auto const [done, action_set] = dynamics ().step_dynamics (model (), action, std::forward<Args>(args)...);
174+ // Transition the environment to the next state
175+ auto [done, action_set] = dynamics ().step_dynamics (model (), action, std::forward<Args>(args)...);
174176 can_transition = !done;
175177
176- auto observation = OptionalObservation{};
177- if (!done) {
178- observation = observation_function ().extract (model (), done);
179- }
178+ // Extract additional information to be returned by step
179+ auto [reward, observation, information] = extract_reward_observation_information (done);
180+
180181 return {
181182 std::move (observation),
182183 std::move (action_set),
183- reward_function (). extract ( model (), done ),
184+ std::move (reward ),
184185 done,
185- information_function (). extract ( model (), done ),
186+ std::move (information ),
186187 };
187188 } catch (std::exception const &) {
188189 can_transition = false ;
@@ -201,12 +202,22 @@ class Environment {
201202private:
202203 Dynamics the_dynamics;
203204 scip::Model the_model;
204- ObservationFunction the_observation_function;
205205 RewardFunction the_reward_function;
206+ ObservationFunction the_observation_function;
206207 InformationFunction the_information_function;
207208 std::map<std::string, scip::Param> the_scip_params;
208209 RandomEngine the_random_engine;
209210 bool can_transition = false ;
211+
212+ // extract reward, observation and information (in that order)
213+ auto extract_reward_observation_information (bool done) -> std::tuple<Reward, OptionalObservation, InformationMap> {
214+ auto reward = reward_function ().extract (model (), done);
215+ // Don't extract observations in final states
216+ auto observation = done ? OptionalObservation{} : observation_function ().extract (model (), done);
217+ auto information = information_function ().extract (model (), done);
218+
219+ return {std::move (reward), std::move (observation), std::move (information)};
220+ }
210221};
211222
212223} // namespace ecole::environment
0 commit comments