Skip to content

Commit 899f346

Browse files
authored
Loop around level loader on finishing all levels (#13)
2 parents 9432d14 + b0905d0 commit 899f346

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

envpool/sokoban/level_loader.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ LevelLoader::LevelLoader(const std::filesystem::path& base_path,
3232
int env_id, int num_envs, int verbose)
3333
: load_sequentially_(load_sequentially),
3434
n_levels_to_load_(n_levels_to_load),
35+
env_id_(env_id),
3536
num_envs_(num_envs),
3637
cur_level_(env_id),
3738
verbose(verbose) {
@@ -54,6 +55,7 @@ LevelLoader::LevelLoader(const std::filesystem::path& base_path,
5455
throw std::runtime_error(
5556
"n_levels_to_load must be a multiple of num_envs.");
5657
}
58+
n_levels_to_load_ /= num_envs_;
5759
}
5860

5961
static const std::array<char, kMaxLevelObject + 1> kPrintLevelKey{
@@ -186,16 +188,19 @@ void LevelLoader::LoadFile(std::mt19937& gen) {
186188

187189
std::vector<SokobanLevel>::iterator LevelLoader::GetLevel(std::mt19937& gen) {
188190
if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) {
189-
throw std::runtime_error("Loaded all requested levels.");
191+
// std::cerr << "Warning: All levels loaded. Looping around now." <<
192+
// std::endl;
193+
levels_loaded_ = 0;
194+
cur_file_ = level_file_paths_.begin();
195+
LoadFile(gen);
196+
// re-start from the `env_id`th level, like we do in the constructor.
197+
cur_level_ = env_id_;
190198
}
191199
// Load new files until the current level index is within the loaded levels
192200
// this is required when new files have lesser levels than the number of envs
193201
while (cur_level_ >= levels_.size()) {
194202
cur_level_ -= levels_.size();
195203
LoadFile(gen);
196-
if (levels_.empty()) { // new file is empty
197-
throw std::runtime_error("No levels loaded.");
198-
}
199204
}
200205
// no need for bound checks since it is checked in the while loop above
201206
auto out = levels_.begin() + cur_level_;

envpool/sokoban/sokoban_py_envpool_test.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -305,19 +305,21 @@ def test_load_sequentially_with_multiple_envs() -> None:
305305
)
306306
dim_room = env.spec.config.dim_room
307307
printed_obs = []
308-
for _ in range(total_levels // num_envs):
309-
obs, _ = env.reset()
310-
assert obs.shape == (
311-
num_envs,
312-
3,
313-
dim_room,
314-
dim_room,
315-
), f"obs shape: {obs.shape}"
316-
for idx in range(num_envs):
317-
printed_obs.append(print_obs(obs[idx]).strip().split("\n"))
318-
for i, level in enumerate(levels_by_files):
319-
for j, line in enumerate(level):
320-
assert printed_obs[i][j] == line, f"Level {i} is not loaded correctly."
308+
309+
for _ in range(2): # check loader loops around and loads levels again
310+
for _ in range(total_levels // num_envs):
311+
obs, _ = env.reset()
312+
assert obs.shape == (
313+
num_envs,
314+
3,
315+
dim_room,
316+
dim_room,
317+
), f"obs shape: {obs.shape}"
318+
for idx in range(num_envs):
319+
printed_obs.append(print_obs(obs[idx]).strip().split("\n"))
320+
for i, level in enumerate(levels_by_files):
321+
for j, line in enumerate(level):
322+
assert printed_obs[i][j] == line, f"Level {i} is not loaded correctly."
321323

322324

323325
def test_astar_log(tmp_path) -> None:

0 commit comments

Comments
 (0)