1717import re
1818import subprocess
1919import sys
20- import tempfile
2120import time
2221
2322import numpy as np
@@ -111,7 +110,7 @@ def test_envpool_load_sequentially(capfd) -> None:
111110 levels_dir = "/app/envpool/sokoban/sample_levels"
112111 files = glob .glob (f"{ levels_dir } /*.txt" )
113112 levels_by_files = []
114- for file in files :
113+ for file in sorted ( files ) :
115114 with open (file , "r" ) as f :
116115 text = f .read ()
117116 levels = text .split ("\n ;" )
@@ -204,10 +203,10 @@ def print_obs(obs: np.ndarray):
204203
205204
206205action_astar_to_envpool = {
207- "0" : 1 ,
208- "1" : 4 ,
209- "2" : 2 ,
210- "3" : 3 ,
206+ "0" : 0 ,
207+ "1" : 3 ,
208+ "2" : 1 ,
209+ "3" : 2 ,
211210}
212211
213212
@@ -243,7 +242,7 @@ def test_solved_level_does_not_truncate(solve_on_time: bool):
243242 )
244243 assert not term and not trunc , "Level should not have reached time limit"
245244
246- NOOP = 0
245+ wrong_action = str (( int ( SOLVE_LEVEL_ZERO [ - 1 ]) + 1 ) % 4 )
247246
248247 if solve_on_time :
249248 obs , reward , term , trunc , infos = env .step (
@@ -256,30 +255,31 @@ def test_solved_level_does_not_truncate(solve_on_time: bool):
256255 assert term and not trunc , "Level should finish within the time limit"
257256
258257 else :
259- obs , reward , term , trunc , infos = env .step (make_1d_array (NOOP ))
258+ obs , reward , term , trunc , infos = env .step (make_1d_array (wrong_action ))
260259 assert not term and trunc , "Level should truncate at precisely this step"
261260
262- _ , _ , term , trunc , _ = env .step (make_1d_array (NOOP ))
261+ _ , _ , term , trunc , _ = env .step (make_1d_array (wrong_action ))
263262 assert not term and not trunc , "Level should reset correctly"
264263
265264
266- @pytest .mark .skip
267- def test_astar_log () -> None :
265+ def test_astar_log (tmp_path ) -> None :
268266 level_file_name = "/app/envpool/sokoban/sample_levels/small.txt"
269- with tempfile .NamedTemporaryFile () as f :
270- log_file_name = f .name
271- subprocess .run (
272- [
273- "/root/go/bin/bazel" , "run" , "//envpool/sokoban:astar_log" , "--" ,
274- level_file_name , log_file_name , "1"
275- ],
276- check = True ,
277- cwd = "/app" ,
278- env = dict (HOME = "/root" ),
279- )
280- with open (log_file_name , "r" ) as f :
281- log = f .read ()
282- assert f"1, { SOLVE_LEVEL_ZERO } , 21, 1443" == log .split ("\n " )[1 ]
267+ log_file_name = tmp_path / "log_file.csv"
268+ subprocess .run (
269+ [
270+ "/root/go/bin/bazel" , f"--output_base={ str (tmp_path )} " , "run" ,
271+ "//envpool/sokoban:astar_log" , "--" , level_file_name ,
272+ str (log_file_name ), "1"
273+ ],
274+ check = True ,
275+ cwd = "/app/envpool" ,
276+ env = {
277+ "HOME" : "/root" ,
278+ "PATH" : "/opt/conda/bin:/usr/bin"
279+ },
280+ )
281+ log = log_file_name .read_text ()
282+ assert f"0,{ SOLVE_LEVEL_ZERO } ,21,1380" == log .split ("\n " )[1 ]
283283
284284
285285if __name__ == "__main__" :
0 commit comments