File tree Expand file tree Collapse file tree 2 files changed +13
-1
lines changed
Expand file tree Collapse file tree 2 files changed +13
-1
lines changed Original file line number Diff line number Diff line change 2424
2525namespace sokoban {
2626
27- void SokobanEnv::Reset () {
27+ void SokobanEnv::ResetWithoutWrite () {
2828 const int max_episode_steps = spec_.config [" max_episode_steps" _];
2929 const int min_episode_steps = spec_.config [" min_episode_steps" _];
3030 current_max_episode_steps_ =
@@ -52,6 +52,10 @@ void SokobanEnv::Reset() {
5252 }
5353 }
5454 current_step_ = 0 ;
55+ }
56+
57+ void SokobanEnv::Reset () {
58+ ResetWithoutWrite ();
5559 WriteState (0 .0f );
5660}
5761
@@ -142,6 +146,7 @@ void SokobanEnv::Step(const Action& action_dict) {
142146 reward_box_ * static_cast <double >(prev_unmatched_boxes -
143147 unmatched_boxes_) +
144148 ((unmatched_boxes_ == 0 ) ? reward_finished_ : 0 .0f );
149+
145150 WriteState (static_cast <float >(reward));
146151}
147152
@@ -177,6 +182,12 @@ void SokobanEnv::WriteState(float reward) {
177182 throw std::runtime_error (msg.str ());
178183 }
179184
185+ if (IsDone ()) {
186+ // If this episode truncates or terminates, the observation should be the
187+ // one for the next episode.
188+ ResetWithoutWrite ();
189+ }
190+
180191 std::vector<uint8_t > out (3 * world_.size ());
181192 for (int rgb = 0 ; rgb < 3 ; rgb++) {
182193 for (size_t i = 0 ; i < world_.size (); i++) {
Original file line number Diff line number Diff line change @@ -115,6 +115,7 @@ class SokobanEnv : public Env<SokobanEnvSpec> {
115115
116116 [[nodiscard]] uint8_t WorldAt (int x, int y) const ;
117117 void WorldAssignAt (int x, int y, uint8_t value);
118+ void ResetWithoutWrite ();
118119};
119120
120121using SokobanEnvPool = AsyncEnvPool<SokobanEnv>;
You can’t perform that action at this time.
0 commit comments