Skip to content

Commit 9432d14

Browse files
authored
Fix sequential loading when using multiple envs (#12)
2 parents ce439db + 12f813a commit 9432d14

File tree

4 files changed

+81
-12
lines changed

4 files changed

+81
-12
lines changed

envpool/sokoban/level_loader.cc

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ namespace sokoban {
2929

3030
LevelLoader::LevelLoader(const std::filesystem::path& base_path,
3131
bool load_sequentially, int n_levels_to_load,
32-
int verbose)
32+
int env_id, int num_envs, int verbose)
3333
: load_sequentially_(load_sequentially),
3434
n_levels_to_load_(n_levels_to_load),
35-
cur_level_(levels_.begin()),
35+
num_envs_(num_envs),
36+
cur_level_(env_id),
3637
verbose(verbose) {
3738
if (std::filesystem::is_regular_file(base_path)) {
3839
level_file_paths_.push_back(base_path);
@@ -49,6 +50,10 @@ LevelLoader::LevelLoader(const std::filesystem::path& base_path,
4950
});
5051
}
5152
cur_file_ = level_file_paths_.begin();
53+
if (n_levels_to_load_ > 0 && n_levels_to_load_ % num_envs_ != 0) {
54+
throw std::runtime_error(
55+
"n_levels_to_load must be a multiple of num_envs.");
56+
}
5257
}
5358

5459
static const std::array<char, kMaxLevelObject + 1> kPrintLevelKey{
@@ -183,15 +188,18 @@ std::vector<SokobanLevel>::iterator LevelLoader::GetLevel(std::mt19937& gen) {
183188
if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) {
184189
throw std::runtime_error("Loaded all requested levels.");
185190
}
186-
if (cur_level_ == levels_.end()) {
191+
// Load new files until the current level index is within the loaded levels.
192+
// This is required when a new file has fewer levels than the number of envs.
193+
while (cur_level_ >= levels_.size()) {
194+
cur_level_ -= levels_.size();
187195
LoadFile(gen);
188-
cur_level_ = levels_.begin();
189-
if (cur_level_ == levels_.end()) {
196+
if (levels_.empty()) { // new file is empty
190197
throw std::runtime_error("No levels loaded.");
191198
}
192199
}
193-
auto out = cur_level_;
194-
cur_level_++;
200+
// No bounds check is needed here: the while loop above only exits once cur_level_ < levels_.size().
201+
auto out = levels_.begin() + cur_level_;
202+
cur_level_ += num_envs_;
195203
levels_loaded_++;
196204
return out;
197205
}

envpool/sokoban/level_loader.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ class LevelLoader {
3939
bool load_sequentially_;
4040
int n_levels_to_load_;
4141
int levels_loaded_{0};
42+
int env_id_{0};
43+
int num_envs_{1};
4244
std::vector<SokobanLevel> levels_{0};
43-
std::vector<SokobanLevel>::iterator cur_level_;
45+
int cur_level_;
4446
std::vector<std::filesystem::path> level_file_paths_{0};
4547
std::vector<std::filesystem::path>::iterator cur_file_;
4648
void LoadFile(std::mt19937& gen);
@@ -51,7 +53,7 @@ class LevelLoader {
5153
std::vector<SokobanLevel>::iterator GetLevel(std::mt19937& gen);
5254
explicit LevelLoader(const std::filesystem::path& base_path,
5355
bool load_sequentially, int n_levels_to_load,
54-
int verbose = 0);
56+
int env_id = 0, int num_envs = 1, int verbose = 0);
5557
};
5658

5759
void PrintLevel(std::ostream& os, const SokobanLevel& vec);

envpool/sokoban/sokoban_envpool.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class SokobanEnv : public Env<SokobanEnvSpec> {
7070
levels_dir_{static_cast<std::string>(spec.config["levels_dir"_])},
7171
level_loader_(levels_dir_, spec.config["load_sequentially"_],
7272
static_cast<int>(spec.config["n_levels_to_load"_]),
73+
env_id, static_cast<int>(spec.config["num_envs"_]),
7374
static_cast<int>(spec.config["verbose"_])),
7475
world_(kWall, static_cast<std::size_t>(dim_room_ * dim_room_)),
7576
verbose_(static_cast<int>(spec.config["verbose"_])),

envpool/sokoban/sokoban_py_envpool_test.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import subprocess
1919
import sys
2020
import time
21+
from pathlib import Path
22+
from typing import List
2123

2224
import numpy as np
2325
import pytest
@@ -187,19 +189,21 @@ def test_xla() -> None:
187189

188190
def print_obs(obs: np.ndarray):
189191
assert obs.shape == (3, 10, 10)
192+
printed = ""
190193
for y in range(obs.shape[1]):
191194
for x in range(obs.shape[2]):
192195
arr = obs[:, y, x]
193196
printed_any = False
194197
for color, symbol in TINY_COLORS:
195198
assert arr.shape == (3,)
196199
if np.array_equal(arr, color):
197-
print(symbol, end="")
200+
printed += symbol
198201
printed_any = True
199202
break
200203
assert printed_any, f"Could not find match for {arr}"
201-
print("\n", end="")
202-
print("\n", end="")
204+
printed += "\n"
205+
printed += "\n"
206+
return printed
203207

204208

205209
action_astar_to_envpool = {
@@ -262,6 +266,60 @@ def test_solved_level_does_not_truncate(solve_on_time: bool):
262266
assert not term and not trunc, "Level should reset correctly"
263267

264268

269+
def read_levels_file(fpath: Path) -> List[List[str]]:
270+
maps = []
271+
current_map = []
272+
with open(fpath, "r") as sf:
273+
for line in sf.readlines():
274+
if ";" in line and current_map:
275+
maps.append(current_map)
276+
current_map = []
277+
if "#" == line[0]:
278+
current_map.append(line.strip())
279+
280+
maps.append(current_map)
281+
return maps
282+
283+
284+
def test_load_sequentially_with_multiple_envs() -> None:
285+
levels_dir = "/app/envpool/sokoban/sample_levels"
286+
files = glob.glob(f"{levels_dir}/*.txt")
287+
levels_by_files = []
288+
total_levels, num_envs = 8, 2
289+
for file in sorted(files):
290+
levels = read_levels_file(file)
291+
levels_by_files.extend(levels)
292+
assert len(levels_by_files) == total_levels, "8 levels stored in files."
293+
294+
env = envpool.make(
295+
"Sokoban-v0",
296+
env_type="gymnasium",
297+
num_envs=num_envs,
298+
batch_size=num_envs,
299+
max_episode_steps=60,
300+
min_episode_steps=60,
301+
levels_dir=levels_dir,
302+
load_sequentially=True,
303+
n_levels_to_load=total_levels,
304+
verbose=2,
305+
)
306+
dim_room = env.spec.config.dim_room
307+
printed_obs = []
308+
for _ in range(total_levels // num_envs):
309+
obs, _ = env.reset()
310+
assert obs.shape == (
311+
num_envs,
312+
3,
313+
dim_room,
314+
dim_room,
315+
), f"obs shape: {obs.shape}"
316+
for idx in range(num_envs):
317+
printed_obs.append(print_obs(obs[idx]).strip().split("\n"))
318+
for i, level in enumerate(levels_by_files):
319+
for j, line in enumerate(level):
320+
assert printed_obs[i][j] == line, f"Level {i} is not loaded correctly."
321+
322+
265323
def test_astar_log(tmp_path) -> None:
266324
level_file_name = "/app/envpool/sokoban/sample_levels/small.txt"
267325
log_file_name = tmp_path / "log_file.csv"

0 commit comments

Comments
 (0)