Skip to content

Commit ce439db

Browse files
authored
Fix delayed resetting on solving a level or on truncation (#11)
2 parents (9b32dd0 + 6b1b577), commit ce439db

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

envpool/sokoban/sokoban_envpool.cc

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424

2525
namespace sokoban {
2626

27-
void SokobanEnv::Reset() {
27+
void SokobanEnv::ResetWithoutWrite() {
2828
const int max_episode_steps = spec_.config["max_episode_steps"_];
2929
const int min_episode_steps = spec_.config["min_episode_steps"_];
3030
current_max_episode_steps_ =
@@ -52,6 +52,10 @@ void SokobanEnv::Reset() {
5252
}
5353
}
5454
current_step_ = 0;
55+
}
56+
57+
void SokobanEnv::Reset() {
58+
ResetWithoutWrite();
5559
WriteState(0.0f);
5660
}
5761

@@ -142,6 +146,7 @@ void SokobanEnv::Step(const Action& action_dict) {
142146
reward_box_ * static_cast<double>(prev_unmatched_boxes -
143147
unmatched_boxes_) +
144148
((unmatched_boxes_ == 0) ? reward_finished_ : 0.0f);
149+
145150
WriteState(static_cast<float>(reward));
146151
}
147152

@@ -177,6 +182,12 @@ void SokobanEnv::WriteState(float reward) {
177182
throw std::runtime_error(msg.str());
178183
}
179184

185+
if (IsDone()) {
186+
// If this episode truncates or terminates, the observation should be the
187+
// one for the next episode.
188+
ResetWithoutWrite();
189+
}
190+
180191
std::vector<uint8_t> out(3 * world_.size());
181192
for (int rgb = 0; rgb < 3; rgb++) {
182193
for (size_t i = 0; i < world_.size(); i++) {

envpool/sokoban/sokoban_envpool.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,7 @@ class SokobanEnv : public Env<SokobanEnvSpec> {
115115

116116
[[nodiscard]] uint8_t WorldAt(int x, int y) const;
117117
void WorldAssignAt(int x, int y, uint8_t value);
118+
void ResetWithoutWrite();
118119
};
119120

120121
using SokobanEnvPool = AsyncEnvPool<SokobanEnv>;

0 commit comments

Comments
 (0)