Skip to content

Commit 9b32dd0

Browse files
author
Adrià Garriga-Alonso
authored
Fix astar_log test (#10)
2 parents 0183b4b + 9c6a5cc commit 9b32dd0

File tree

2 files changed

+30
-25
lines changed

2 files changed

+30
-25
lines changed

envpool/sokoban/level_loader.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ LevelLoader::LevelLoader(const std::filesystem::path& base_path,
4242
level_file_paths_.push_back(entry.path());
4343
}
4444
}
45+
std::sort(
46+
level_file_paths_.begin(), level_file_paths_.end(),
47+
[](const std::filesystem::path& a, const std::filesystem::path& b) {
48+
return a.filename().string() < b.filename().string();
49+
});
4550
}
4651
cur_file_ = level_file_paths_.begin();
4752
}

envpool/sokoban/sokoban_py_envpool_test.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import re
1818
import subprocess
1919
import sys
20-
import tempfile
2120
import time
2221

2322
import numpy as np
@@ -111,7 +110,7 @@ def test_envpool_load_sequentially(capfd) -> None:
111110
levels_dir = "/app/envpool/sokoban/sample_levels"
112111
files = glob.glob(f"{levels_dir}/*.txt")
113112
levels_by_files = []
114-
for file in files:
113+
for file in sorted(files):
115114
with open(file, "r") as f:
116115
text = f.read()
117116
levels = text.split("\n;")
@@ -204,10 +203,10 @@ def print_obs(obs: np.ndarray):
204203

205204

206205
action_astar_to_envpool = {
207-
"0": 1,
208-
"1": 4,
209-
"2": 2,
210-
"3": 3,
206+
"0": 0,
207+
"1": 3,
208+
"2": 1,
209+
"3": 2,
211210
}
212211

213212

@@ -243,7 +242,7 @@ def test_solved_level_does_not_truncate(solve_on_time: bool):
243242
)
244243
assert not term and not trunc, "Level should not have reached time limit"
245244

246-
NOOP = 0
245+
wrong_action = str((int(SOLVE_LEVEL_ZERO[-1]) + 1) % 4)
247246

248247
if solve_on_time:
249248
obs, reward, term, trunc, infos = env.step(
@@ -256,30 +255,31 @@ def test_solved_level_does_not_truncate(solve_on_time: bool):
256255
assert term and not trunc, "Level should finish within the time limit"
257256

258257
else:
259-
obs, reward, term, trunc, infos = env.step(make_1d_array(NOOP))
258+
obs, reward, term, trunc, infos = env.step(make_1d_array(wrong_action))
260259
assert not term and trunc, "Level should truncate at precisely this step"
261260

262-
_, _, term, trunc, _ = env.step(make_1d_array(NOOP))
261+
_, _, term, trunc, _ = env.step(make_1d_array(wrong_action))
263262
assert not term and not trunc, "Level should reset correctly"
264263

265264

266-
@pytest.mark.skip
267-
def test_astar_log() -> None:
265+
def test_astar_log(tmp_path) -> None:
268266
level_file_name = "/app/envpool/sokoban/sample_levels/small.txt"
269-
with tempfile.NamedTemporaryFile() as f:
270-
log_file_name = f.name
271-
subprocess.run(
272-
[
273-
"/root/go/bin/bazel", "run", "//envpool/sokoban:astar_log", "--",
274-
level_file_name, log_file_name, "1"
275-
],
276-
check=True,
277-
cwd="/app",
278-
env=dict(HOME="/root"),
279-
)
280-
with open(log_file_name, "r") as f:
281-
log = f.read()
282-
assert f"1, {SOLVE_LEVEL_ZERO}, 21, 1443" == log.split("\n")[1]
267+
log_file_name = tmp_path / "log_file.csv"
268+
subprocess.run(
269+
[
270+
"/root/go/bin/bazel", f"--output_base={str(tmp_path)}", "run",
271+
"//envpool/sokoban:astar_log", "--", level_file_name,
272+
str(log_file_name), "1"
273+
],
274+
check=True,
275+
cwd="/app/envpool",
276+
env={
277+
"HOME": "/root",
278+
"PATH": "/opt/conda/bin:/usr/bin"
279+
},
280+
)
281+
log = log_file_name.read_text()
282+
assert f"0,{SOLVE_LEVEL_ZERO},21,1380" == log.split("\n")[1]
283283

284284

285285
if __name__ == "__main__":

0 commit comments

Comments
 (0)