Skip to content

Commit 9b91276

Browse files
Test external custom environment (#24)
1 parent 8b91e82 commit 9b91276

File tree

3 files changed

+88
-2
lines changed

3 files changed

+88
-2
lines changed

examples/grid_world/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The `difficulty` parameter limits how far the goal can be from the starting posi
66

77
## Build
88

9-
Compile the module with `maturin develop` or `python -m pip install -e .` if you have [maturin](https://github.com/PyO3/maturin) installed.
9+
Compile the module with `python -m pip install -e .` or `maturin develop` if you have the [maturin](https://github.com/PyO3/maturin) package installed.
1010

1111
## Usage
1212

examples/grid_world/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ packages = ["grid_world"]
1616
[tool.maturin]
1717
features = ["pyo3/extension-module"]
1818
profile = "release"
19-
module-name = "twisterl.grid_world"
19+
module-name = "grid_world"
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import functools
2+
import importlib
3+
import os
4+
import subprocess
5+
import sys
6+
from pathlib import Path
7+
8+
import pytest
9+
10+
from twisterl.utils import load_config, prepare_algorithm
11+
12+
13+
@functools.lru_cache(maxsize=1)
14+
def _ensure_grid_world_available():
15+
"""Install the external grid_world example if it is not already importable."""
16+
module_name = "grid_world"
17+
try:
18+
importlib.import_module(module_name)
19+
return module_name
20+
except ModuleNotFoundError:
21+
pass
22+
23+
# pytest.importorskip(
24+
# "maturin", reason="Grid World example needs maturin to build the extension"
25+
# )
26+
example_dir = Path(__file__).resolve().parents[1] / "examples" / "grid_world"
27+
env = os.environ.copy()
28+
venv_bin = Path(sys.executable).resolve().parent
29+
env["PATH"] = f"{venv_bin}{os.pathsep}{env.get('PATH', '')}"
30+
env.setdefault("VIRTUAL_ENV", str(venv_bin.parent))
31+
32+
subprocess.run(
33+
[
34+
sys.executable,
35+
"-m",
36+
"pip",
37+
"install",
38+
"-e",
39+
"."
40+
],
41+
check=True,
42+
cwd=str(example_dir),
43+
stdout=subprocess.PIPE,
44+
stderr=subprocess.PIPE,
45+
env=env,
46+
)
47+
importlib.invalidate_caches()
48+
importlib.import_module(module_name)
49+
return module_name
50+
51+
52+
def test_grid_world_external_env_works_with_twisterl():
53+
module_name = _ensure_grid_world_available()
54+
55+
config_path = (
56+
Path(__file__).resolve().parents[1]
57+
/ "examples"
58+
/ "grid_world"
59+
/ "ppo_grid_world_5x5_v1.json"
60+
)
61+
algo_config = load_config(config_path)
62+
algo_cfg = algo_config["algorithm"]
63+
algo_cfg["collecting"].update({"num_cores": 1, "num_episodes": 1})
64+
65+
grid_world_module = importlib.import_module(module_name)
66+
GridWorld = grid_world_module.GridWorld
67+
68+
env_args = algo_config["env"].copy()
69+
env = GridWorld(**env_args)
70+
env.reset()
71+
# Observations & states should match the board size
72+
state = env.get_state()
73+
expected_cells = env_args["width"] * env_args["height"]
74+
assert len(state) == expected_cells
75+
assert {0, 1, 2, 3}.issuperset(set(state))
76+
77+
algo = prepare_algorithm(algo_config)
78+
79+
# Ensure the basic interactions work as expected
80+
algo.env.reset()
81+
observed = algo.env.observe()
82+
assert len(observed) == algo.env.obs_shape()[0]
83+
84+
collected, _ = algo.collect()
85+
assert collected.obs
86+
assert collected.additional_data["rets"]

0 commit comments

Comments
 (0)