Skip to content

Commit ec608f1

Browse files
coyettevclement-bonnetdluo96
authored
refactor(cleaner): return state from instance generator (#87)
Co-authored-by: Clément Bonnet <56230714+clement-bonnet@users.noreply.github.com> Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com>
1 parent 8b394e2 commit ec608f1

File tree

10 files changed

+105
-85
lines changed

10 files changed

+105
-85
lines changed

jumanji/environments/commons/maze_utils/maze_generation.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
import chex
4242
import jax
4343
import jax.numpy as jnp
44-
from typing_extensions import TypeAlias
4544

4645
from jumanji.environments.commons.maze_utils.stack import (
4746
Stack,
@@ -51,8 +50,6 @@
5150
stack_push,
5251
)
5352

54-
Maze: TypeAlias = chex.Array
55-
5653
EMPTY = 0
5754
WALL = 1
5855

@@ -65,7 +62,7 @@ class MazeGenerationState(NamedTuple):
6562
- key: the Jax random generation key.
6663
"""
6764

68-
maze: Maze
65+
maze: chex.Array
6966
chambers: Stack
7067
key: chex.PRNGKey
7168

@@ -79,7 +76,7 @@ def create_chambers_stack(maze_width: int, maze_height: int) -> Stack:
7976
return stack_push(chambers, jnp.array([0, 0, maze_width, maze_height]))
8077

8178

82-
def create_empty_maze(width: int, height: int) -> Maze:
79+
def create_empty_maze(width: int, height: int) -> chex.Array:
8380
"""Create an empty maze."""
8481
return jnp.full((height, width), EMPTY, dtype=jnp.int8)
8582

@@ -94,19 +91,19 @@ def random_odd(key: chex.PRNGKey, max_val: int) -> chex.Array:
9491
return jax.random.randint(key, (), 0, max_val // 2) * 2 + 1
9592

9693

97-
def draw_horizontal_wall(maze: Maze, x: int, y: int, width: int) -> Maze:
94+
def draw_horizontal_wall(maze: chex.Array, x: int, y: int, width: int) -> chex.Array:
9895
"""Draw a horizontal wall on the maze starting from (x,y) with the specified width."""
9996

100-
def body_fun(i: int, maze: Maze) -> Maze:
97+
def body_fun(i: int, maze: chex.Array) -> chex.Array:
10198
return maze.at[y, i].set(WALL)
10299

103100
return jax.lax.fori_loop(x, x + width, body_fun, maze)
104101

105102

106-
def draw_vertical_wall(maze: Maze, x: int, y: int, height: int) -> Maze:
103+
def draw_vertical_wall(maze: chex.Array, x: int, y: int, height: int) -> chex.Array:
107104
"""Draw a vertical wall on the maze starting from (x,y) with the specified height."""
108105

109-
def body_fun(i: int, maze: Maze) -> Maze:
106+
def body_fun(i: int, maze: chex.Array) -> chex.Array:
110107
return maze.at[i, x].set(WALL)
111108

112109
return jax.lax.fori_loop(y, y + height, body_fun, maze)
@@ -156,7 +153,7 @@ def split_vertically(
156153

157154
def split_horizontally(
158155
state: MazeGenerationState, chamber: chex.Array
159-
) -> Tuple[Maze, Stack, chex.PRNGKey]:
156+
) -> Tuple[chex.Array, Stack, chex.PRNGKey]:
160157
"""Split the chamber horizontally.
161158
162159
Randomly draw a vertical wall to split the chamber horizontally. Randomly open a passage
@@ -202,7 +199,7 @@ def chambers_remaining(state: MazeGenerationState) -> int:
202199
return ~empty_stack(state.chambers)
203200

204201

205-
def generate_maze(width: int, height: int, key: chex.PRNGKey) -> Maze:
202+
def generate_maze(width: int, height: int, key: chex.PRNGKey) -> chex.Array:
206203
"""Randomly generate a maze.
207204
208205
Args:

jumanji/environments/commons/maze_utils/maze_generation_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from jumanji.environments.commons.maze_utils.maze_generation import (
2222
EMPTY,
2323
WALL,
24-
Maze,
2524
MazeGenerationState,
2625
create_chambers_stack,
2726
create_empty_maze,
@@ -34,7 +33,7 @@
3433
from jumanji.environments.commons.maze_utils.stack import Stack, stack_pop
3534

3635

37-
def no_more_chamber(maze: Maze) -> chex.Array:
36+
def no_more_chamber(maze: chex.Array) -> chex.Array:
3837
"""Test if there is no chamber in the maze that can be divided anymore.
3938
4039
A chamber can be divided if its width and height are greater or equal to two.
@@ -48,7 +47,7 @@ def no_more_chamber(maze: Maze) -> chex.Array:
4847
return jnp.all(convolved)
4948

5049

51-
def all_tiles_connected(maze: Maze) -> bool:
50+
def all_tiles_connected(maze: chex.Array) -> bool:
5251
"""Test if all the tiles of the maze can be reached.
5352
5453
The function scipy.ndimage.label can be used to count the number of connected components
@@ -68,7 +67,7 @@ class TestMazeGeneration:
6867
HEIGHT = 15
6968

7069
@pytest.fixture
71-
def maze(self) -> Maze:
70+
def maze(self) -> chex.Array:
7271
return create_empty_maze(self.WIDTH, self.HEIGHT)
7372

7473
@pytest.fixture
@@ -111,7 +110,7 @@ def test_random_odd(self, key: chex.PRNGKey) -> None:
111110
assert 0 <= i < max_val
112111

113112
def test_split_vertically(
114-
self, maze: Maze, chambers: Stack, key: chex.PRNGKey
113+
self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey
115114
) -> None:
116115
"""Test that a horizontal wall is drawn and that subchambers are added to stack."""
117116
chambers, chamber = stack_pop(chambers)
@@ -126,7 +125,7 @@ def test_split_vertically(
126125
assert chambers.insertion_index >= 1
127126

128127
def test_split_horizontally(
129-
self, maze: Maze, chambers: Stack, key: chex.PRNGKey
128+
self, maze: chex.Array, chambers: Stack, key: chex.PRNGKey
130129
) -> None:
131130
"""Test that a vertical wall is drawn and that subchambers are added to stack."""
132131
chambers, chamber = stack_pop(chambers)

jumanji/environments/commons/maze_utils/maze_rendering.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from typing import Callable, Optional, Sequence, Tuple
1616

17+
import chex
1718
import matplotlib.animation
1819
import matplotlib.cm
1920
import matplotlib.pyplot as plt
@@ -23,7 +24,7 @@
2324
from numpy.typing import NDArray
2425

2526
import jumanji.environments
26-
from jumanji.environments.commons.maze_utils.maze_generation import EMPTY, WALL, Maze
27+
from jumanji.environments.commons.maze_utils.maze_generation import EMPTY, WALL
2728
from jumanji.viewer import Viewer
2829

2930

@@ -55,7 +56,7 @@ def __init__(self, name: str, render_mode: str = "human") -> None:
5556
else:
5657
raise ValueError(f"Invalid render mode: {render_mode}")
5758

58-
def render(self, maze: Maze) -> Optional[NDArray]:
59+
def render(self, maze: chex.Array) -> Optional[NDArray]:
5960
"""
6061
Render maze.
6162
@@ -73,7 +74,7 @@ def render(self, maze: Maze) -> Optional[NDArray]:
7374

7475
def animate(
7576
self,
76-
mazes: Sequence[Maze],
77+
mazes: Sequence[chex.Array],
7778
interval: int = 200,
7879
save_path: Optional[str] = None,
7980
) -> matplotlib.animation.FuncAnimation:
@@ -124,12 +125,12 @@ def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
124125
ax = fig.get_axes()[0]
125126
return fig, ax
126127

127-
def _add_grid_image(self, maze: Maze, ax: Axes) -> image.AxesImage:
128+
def _add_grid_image(self, maze: chex.Array, ax: Axes) -> image.AxesImage:
128129
img = self._create_grid_image(maze)
129130
ax.set_axis_off()
130131
return ax.imshow(img)
131132

132-
def _create_grid_image(self, maze: Maze) -> NDArray:
133+
def _create_grid_image(self, maze: chex.Array) -> NDArray:
133134
img = np.zeros((*maze.shape, 3))
134135
for tile_value, color in self.COLORS.items():
135136
img[np.where(maze == tile_value)] = color

jumanji/environments/routing/cleaner/env.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ class Cleaner(Environment[State]):
8282

8383
def __init__(
8484
self,
85-
num_agents: int = 3,
8685
generator: Optional[Generator] = None,
8786
time_limit: Optional[int] = None,
8887
penalty_per_timestep: float = 0.5,
@@ -95,13 +94,15 @@ def __init__(
9594
time_limit: max number of steps in an episode. Defaults to `num_rows * num_cols`.
9695
generator: `Generator` whose `__call__` instantiates an environment instance.
9796
Implemented options are [`RandomGenerator`]. Defaults to `RandomGenerator` with
98-
`num_rows=10` and `num_cols=10`.
97+
`num_rows=10`, `num_cols=10` and `num_agents=3`.
9998
viewer: `Viewer` used for rendering. Defaults to `CleanerViewer` with "human" render
10099
mode.
101100
penalty_per_timestep: the penalty returned at each timestep in the reward.
102101
"""
103-
self.num_agents = num_agents
104-
self.generator = generator or RandomGenerator(num_rows=10, num_cols=10)
102+
self.generator = generator or RandomGenerator(
103+
num_rows=10, num_cols=10, num_agents=3
104+
)
105+
self.num_agents = self.generator.num_agents
105106
self.num_rows = self.generator.num_rows
106107
self.num_cols = self.generator.num_cols
107108
self.grid_shape = (self.num_rows, self.num_cols)
@@ -177,22 +178,13 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
177178
timestep: `TimeStep` object corresponding to the first timestep returned by the
178179
environment after a reset.
179180
"""
180-
key, subkey = jax.random.split(key)
181-
182181
# Agents start in upper left corner
183182
agents_locations = jnp.zeros((self.num_agents, 2), int)
184183

185-
grid = self.generator(subkey)
186-
# Clean the tile in upper left corner
187-
grid = self._clean_tiles_containing_agents(grid, agents_locations)
184+
state = self.generator(key)
188185

189-
state = State(
190-
grid=grid,
191-
agents_locations=agents_locations,
192-
action_mask=self._compute_action_mask(grid, agents_locations),
193-
step_count=jnp.array(0, jnp.int32),
194-
key=key,
195-
)
186+
# Create the action mask and update the state
187+
state.action_mask = self._compute_action_mask(state.grid, agents_locations)
196188

197189
observation = self._observation_from_state(state)
198190

jumanji/environments/routing/cleaner/env_test.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,37 +19,45 @@
1919

2020
from jumanji.environments.routing.cleaner.constants import CLEAN, DIRTY, WALL
2121
from jumanji.environments.routing.cleaner.env import Cleaner
22-
from jumanji.environments.routing.cleaner.generator import Generator, Maze
22+
from jumanji.environments.routing.cleaner.generator import Generator
2323
from jumanji.environments.routing.cleaner.types import Observation, State
2424
from jumanji.testing.env_not_smoke import check_env_does_not_smoke
2525
from jumanji.testing.pytrees import assert_is_jax_array_tree
2626
from jumanji.types import StepType, TimeStep
2727

2828
SAMPLE_GRID = jnp.array(
2929
[
30-
[DIRTY, DIRTY, WALL, DIRTY, DIRTY],
30+
[CLEAN, DIRTY, WALL, DIRTY, DIRTY],
3131
[WALL, DIRTY, WALL, DIRTY, WALL],
3232
[DIRTY, DIRTY, DIRTY, DIRTY, WALL],
3333
[DIRTY, WALL, WALL, DIRTY, WALL],
3434
[DIRTY, WALL, DIRTY, DIRTY, DIRTY],
3535
]
3636
)
37-
N_AGENT = 3
3837

3938

4039
class DummyGenerator(Generator):
41-
def __init__(self) -> None:
42-
super(DummyGenerator, self).__init__(num_rows=5, num_cols=5)
40+
"""Dummy generator, generate an instance of size 5x5 with 3 agents."""
4341

44-
def __call__(self, key: chex.PRNGKey) -> Maze:
45-
return SAMPLE_GRID
42+
def __init__(self) -> None:
43+
super(DummyGenerator, self).__init__(num_rows=5, num_cols=5, num_agents=3)
44+
45+
def __call__(self, key: chex.PRNGKey) -> State:
46+
agents_locations = jnp.zeros((self.num_agents, 2), int)
47+
return State(
48+
grid=SAMPLE_GRID,
49+
agents_locations=agents_locations,
50+
action_mask=None,
51+
step_count=jnp.array(0, jnp.int32),
52+
key=key,
53+
)
4654

4755

4856
class TestCleaner:
4957
@pytest.fixture
5058
def cleaner(self) -> Cleaner:
5159
generator = DummyGenerator()
52-
return Cleaner(num_agents=N_AGENT, generator=generator)
60+
return Cleaner(generator=generator)
5361

5462
@pytest.fixture
5563
def key(self) -> chex.PRNGKey:
@@ -74,7 +82,7 @@ def test_cleaner__reset(self, cleaner: Cleaner, key: chex.PRNGKey) -> None:
7482
assert isinstance(timestep, TimeStep)
7583
assert isinstance(state, State)
7684

77-
assert jnp.all(state.agents_locations == jnp.zeros((N_AGENT, 2)))
85+
assert jnp.all(state.agents_locations == jnp.zeros((cleaner.num_agents, 2)))
7886
assert jnp.sum(state.grid == CLEAN) == 1 # Only the top-left tile is clean
7987
assert state.step_count == 0
8088

@@ -101,7 +109,7 @@ def test_cleaner__step(self, cleaner: Cleaner, key: chex.PRNGKey) -> None:
101109
step_fn = jax.jit(cleaner.step)
102110

103111
# First action: all agents move right
104-
actions = jnp.array([1] * N_AGENT)
112+
actions = jnp.array([1] * cleaner.num_agents)
105113
state, timestep = step_fn(initial_state, actions)
106114
# Assert only one tile changed, on the right of the initial pos
107115
assert jnp.sum(state.grid != initial_state.grid) == 1
@@ -148,7 +156,7 @@ def test_cleaner__initial_action_mask(
148156

149157
# All agents can only move right in the initial state
150158
expected_action_mask = jnp.array(
151-
[[False, True, False, False] for _ in range(N_AGENT)]
159+
[[False, True, False, False] for _ in range(cleaner.num_agents)]
152160
)
153161

154162
assert jnp.all(state.action_mask == expected_action_mask)
@@ -177,7 +185,7 @@ def select_action(
177185
key, jnp.arange(4), p=agent_action_mask.flatten()
178186
)
179187

180-
subkeys = jax.random.split(key, N_AGENT)
188+
subkeys = jax.random.split(key, cleaner.num_agents)
181189
return select_action(subkeys, observation.action_mask)
182190

183191
check_env_does_not_smoke(cleaner, select_actions)

0 commit comments

Comments
 (0)