Skip to content

Commit 36ab997

Browse files
TristanKalloniatis, tristankalloniatis, and dluo96
authored
feat: allow custom rendering and instance generation methods for Minesweeper (#85)
Co-authored-by: tristankalloniatis <t.kalloniatis@instadeep.com> Co-authored-by: Daniel <57721552+dluo96@users.noreply.github.com>
1 parent ec608f1 commit 36ab997

File tree

11 files changed

+398
-244
lines changed

11 files changed

+398
-244
lines changed

jumanji/environments/logic/minesweeper/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@
1818

1919
from jumanji.environments.logic.minesweeper.constants import UNEXPLORED_ID
2020
from jumanji.environments.logic.minesweeper.env import Minesweeper
21+
from jumanji.environments.logic.minesweeper.generator import UniformSamplingGenerator
2122
from jumanji.environments.logic.minesweeper.types import State
2223

2324

2425
@pytest.fixture
2526
def minesweeper_env() -> Minesweeper:
26-
"""Fixture for a default minesweeper env"""
27-
return Minesweeper()
27+
"""Fixture for a default minesweeper environment with 10 rows and columns, and 10 mines."""
28+
return Minesweeper(
29+
generator=UniformSamplingGenerator(num_rows=10, num_cols=10, num_mines=10)
30+
)
2831

2932

3033
@pytest.fixture

jumanji/environments/logic/minesweeper/constants.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
UNEXPLORED_ID: int = -1
1616
IS_MINE: int = 1
1717
PATCH_SIZE: int = 3
18-
REVEALED_EMPTY_SQUARE_REWARD: float = 1.0
19-
REVEALED_MINE_OR_INVALID_ACTION_REWARD: float = 0.0
20-
COLOUR_MAPPING: list = [
18+
DEFAULT_COLOR_MAPPING: list = [
2119
"orange",
2220
"blue",
2321
"green",

jumanji/environments/logic/minesweeper/env.py

Lines changed: 53 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,28 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List, Optional, Sequence, Tuple
15+
from typing import Optional, Sequence, Tuple
1616

1717
import chex
1818
import jax
1919
import jax.numpy as jnp
2020
import matplotlib.animation
21-
import matplotlib.pyplot as plt
21+
from numpy.typing import NDArray
2222

23-
import jumanji.environments
2423
from jumanji import specs
2524
from jumanji.env import Environment
26-
from jumanji.environments.logic.minesweeper.constants import (
27-
COLOUR_MAPPING,
28-
PATCH_SIZE,
29-
UNEXPLORED_ID,
30-
)
25+
from jumanji.environments.logic.minesweeper.constants import PATCH_SIZE, UNEXPLORED_ID
3126
from jumanji.environments.logic.minesweeper.done import DefaultDoneFn, DoneFn
27+
from jumanji.environments.logic.minesweeper.generator import (
28+
Generator,
29+
UniformSamplingGenerator,
30+
)
3231
from jumanji.environments.logic.minesweeper.reward import DefaultRewardFn, RewardFn
3332
from jumanji.environments.logic.minesweeper.types import Observation, State
34-
from jumanji.environments.logic.minesweeper.utils import (
35-
count_adjacent_mines,
36-
create_flat_mine_locations,
37-
explored_mine,
38-
)
33+
from jumanji.environments.logic.minesweeper.utils import count_adjacent_mines
34+
from jumanji.environments.logic.minesweeper.viewer import MinesweeperViewer
3935
from jumanji.types import TimeStep, restart, termination, transition
36+
from jumanji.viewer import Viewer
4037

4138

4239
class Minesweeper(Environment[State]):
@@ -53,7 +50,7 @@ class Minesweeper(Environment[State]):
5350
specifies how many timesteps have elapsed since environment reset.
5451
5552
- action:
56-
multi discrete array containing the square to explore (height and width).
53+
multi discrete array containing the square to explore (row and col).
5754
5855
- reward: jax array (float32):
5956
Configurable function of state and action. By default:
@@ -92,46 +89,47 @@ class Minesweeper(Environment[State]):
9289

9390
def __init__(
9491
self,
95-
num_rows: int = 10,
96-
num_cols: int = 10,
97-
num_mines: int = 10,
92+
generator: Optional[Generator] = None,
9893
reward_function: Optional[RewardFn] = None,
9994
done_function: Optional[DoneFn] = None,
100-
color_mapping: Optional[List[str]] = None,
95+
viewer: Optional[Viewer[State]] = None,
10196
):
10297
"""Instantiate a `Minesweeper` environment.
10398
10499
Args:
105-
num_rows: number of rows, i.e. height of the board. Defaults to 10.
106-
num_cols: number of columns, i.e. width of the board. Defaults to 10.
107-
num_mines: number of mines on the board. Defaults to 10.
100+
generator: `Generator` to generate problem instances on environment reset.
101+
Implemented options are [`UniformSamplingGenerator`]. Defaults to `UniformSamplingGenerator`.
102+
The generator will have attributes:
103+
- num_rows: number of rows, i.e. height of the board. Defaults to 10.
104+
- num_cols: number of columns, i.e. width of the board. Defaults to 10.
105+
- num_mines: number of mines generated. Defaults to 10.
108106
reward_function: `RewardFn` whose `__call__` method computes the reward of an
109107
environment transition based on the given current state and selected action.
110-
Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`.
108+
Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`, giving
109+
a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and
110+
0.0 for an invalid action (selecting an already revealed square).
111111
done_function: `DoneFn` whose `__call__` method computes the done signal given the
112112
current state, action taken, and next state.
113-
Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`.
114-
color_mapping: colour map used for rendering.
113+
Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`, ending the
114+
episode on solving the board, revealing a mine, or picking an invalid action.
115+
viewer: `Viewer` to support rendering and animation methods.
116+
Implemented options are [`MinesweeperViewer`]. Defaults to `MinesweeperViewer`.
115117
"""
116-
if num_rows <= 1 or num_cols <= 1:
117-
raise ValueError(
118-
f"Should make a board of height and width greater than 1, "
119-
f"got num_rows={num_rows}, num_cols={num_cols}"
120-
)
121-
if num_mines < 0 or num_mines >= num_rows * num_cols:
122-
raise ValueError(
123-
f"Number of mines should be constrained between 0 and the size of the board, "
124-
f"got {num_mines}"
125-
)
126-
self.num_rows = num_rows
127-
self.num_cols = num_cols
128-
self.num_mines = num_mines
129-
self.reward_function = reward_function or DefaultRewardFn()
118+
self.reward_function = reward_function or DefaultRewardFn(
119+
revealed_empty_square_reward=1.0,
120+
revealed_mine_reward=0.0,
121+
invalid_action_reward=0.0,
122+
)
130123
self.done_function = done_function or DefaultDoneFn()
131-
132-
self.cmap = color_mapping if color_mapping else COLOUR_MAPPING
133-
self.figure_name = f"{num_rows}x{num_cols} Minesweeper"
134-
self.figure_size = (6.0, 6.0)
124+
self.generator = generator or UniformSamplingGenerator(
125+
num_rows=10, num_cols=10, num_mines=10
126+
)
127+
self.num_rows = self.generator.num_rows
128+
self.num_cols = self.generator.num_cols
129+
self.num_mines = self.generator.num_mines
130+
self._viewer = viewer or MinesweeperViewer(
131+
num_rows=self.num_rows, num_cols=self.num_cols
132+
)
135133

136134
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
137135
"""Resets the environment.
@@ -144,25 +142,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
144142
timestep: `TimeStep` corresponding to the first timestep returned by the
145143
environment.
146144
"""
147-
key, sample_key = jax.random.split(key)
148-
board = jnp.full(
149-
shape=(self.num_rows, self.num_cols),
150-
fill_value=UNEXPLORED_ID,
151-
dtype=jnp.int32,
152-
)
153-
step_count = jnp.array(0, jnp.int32)
154-
flat_mine_locations = create_flat_mine_locations(
155-
key=sample_key,
156-
num_rows=self.num_rows,
157-
num_cols=self.num_cols,
158-
num_mines=self.num_mines,
159-
)
160-
state = State(
161-
board=board,
162-
step_count=step_count,
163-
key=key,
164-
flat_mine_locations=flat_mine_locations,
165-
)
145+
state = self.generator(key)
166146
observation = self._state_to_observation(state=state)
167147
timestep = restart(observation=observation)
168148
return state, timestep
@@ -180,9 +160,7 @@ def step(
180160
next_state: `State` corresponding to the next state of the environment,
181161
next_timestep: `TimeStep` corresponding to the timestep returned by the environment.
182162
"""
183-
board = state.board
184-
action_height, action_width = action
185-
board = board.at[action_height, action_width].set(
163+
board = state.board.at[tuple(action)].set(
186164
count_adjacent_mines(state=state, action=action)
187165
)
188166
step_count = state.step_count + 1
@@ -272,134 +250,38 @@ def _state_to_observation(self, state: State) -> Observation:
272250
step_count=state.step_count,
273251
)
274252

275-
def render(self, state: State) -> None:
276-
"""Render the given environment state using matplotlib.
253+
def render(self, state: State) -> Optional[NDArray]:
254+
"""Renders the current state of the board.
277255
278256
Args:
279-
state: environment state to be rendered.
280-
257+
state: the current state to be rendered.
281258
"""
282-
self._clear_display()
283-
fig, ax = self._get_fig_ax()
284-
self._draw(ax, state)
285-
self._update_display(fig)
259+
return self._viewer.render(state=state)
286260

287261
def animate(
288262
self,
289263
states: Sequence[State],
290264
interval: int = 200,
291265
save_path: Optional[str] = None,
292266
) -> matplotlib.animation.FuncAnimation:
293-
"""Create an animation from a sequence of environment states.
267+
"""Creates an animated gif of the board based on the sequence of states.
294268
295269
Args:
296-
states: sequence of environment states corresponding to consecutive timesteps.
297-
interval: delay between frames in milliseconds, default to 200.
270+
states: a list of `State` objects representing the sequence of states.
271+
interval: the delay between frames in milliseconds, defaults to 200.
298272
save_path: the path where the animation file should be saved. If it is None, the plot
299273
will not be saved.
300274
301275
Returns:
302-
Animation object that can be saved as a GIF, MP4, or rendered with HTML.
276+
animation.FuncAnimation: the animation object that was created.
303277
"""
304-
fig, ax = self._get_fig_ax()
305-
plt.tight_layout()
306-
plt.close(fig)
307-
308-
def make_frame(state_index: int) -> None:
309-
state = states[state_index]
310-
self._draw(ax, state)
311-
312-
# Create the animation object.
313-
self._animation = matplotlib.animation.FuncAnimation(
314-
fig,
315-
make_frame,
316-
frames=len(states),
317-
interval=interval,
278+
return self._viewer.animate(
279+
states=states, interval=interval, save_path=save_path
318280
)
319281

320-
# Save the animation as a GIF.
321-
if save_path:
322-
self._animation.save(save_path)
323-
324-
return self._animation
325-
326282
def close(self) -> None:
327283
"""Perform any necessary cleanup.
328-
329284
Environments will automatically :meth:`close()` themselves when
330285
garbage collected or when the program exits.
331286
"""
332-
plt.close(self.figure_name)
333-
334-
def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
335-
exists = plt.fignum_exists(self.figure_name)
336-
if exists:
337-
fig = plt.figure(self.figure_name)
338-
ax = fig.get_axes()[0]
339-
else:
340-
fig = plt.figure(self.figure_name, figsize=self.figure_size)
341-
plt.suptitle(self.figure_name)
342-
plt.tight_layout()
343-
if not plt.isinteractive():
344-
fig.show()
345-
ax = fig.add_subplot()
346-
return fig, ax
347-
348-
def _draw(self, ax: plt.Axes, state: State) -> None:
349-
ax.clear()
350-
ax.set_xticks(jnp.arange(-0.5, self.num_cols - 1, 1))
351-
ax.set_yticks(jnp.arange(-0.5, self.num_rows - 1, 1))
352-
ax.tick_params(
353-
top=False,
354-
bottom=False,
355-
left=False,
356-
right=False,
357-
labelleft=False,
358-
labelbottom=False,
359-
labeltop=False,
360-
labelright=False,
361-
)
362-
background = jnp.ones_like(state.board)
363-
for i in range(self.num_rows):
364-
for j in range(self.num_cols):
365-
background = self._render_grid_square(
366-
state=state, ax=ax, i=i, j=j, background=background
367-
)
368-
ax.imshow(background, cmap="gray", vmin=0, vmax=1)
369-
ax.grid(color="black", linestyle="-", linewidth=2)
370-
371-
def _render_grid_square(
372-
self, state: State, ax: plt.Axes, i: int, j: int, background: chex.Array
373-
) -> chex.Array:
374-
board_value = state.board[i, j]
375-
if board_value != UNEXPLORED_ID:
376-
if explored_mine(state=state, action=jnp.array([i, j], dtype=jnp.int32)):
377-
background = background.at[i, j].set(0)
378-
else:
379-
ax.text(
380-
j,
381-
i,
382-
str(board_value),
383-
color=self.cmap[board_value],
384-
ha="center",
385-
va="center",
386-
fontsize="xx-large",
387-
)
388-
return background
389-
390-
def _update_display(self, fig: plt.Figure) -> None:
391-
if plt.isinteractive():
392-
# Required to update render when using Jupyter Notebook.
393-
fig.canvas.draw()
394-
if jumanji.environments.is_colab():
395-
plt.show(self.figure_name)
396-
else:
397-
# Required to update render when not using Jupyter Notebook.
398-
fig.canvas.draw_idle()
399-
fig.canvas.flush_events()
400-
401-
def _clear_display(self) -> None:
402-
if jumanji.environments.is_colab():
403-
import IPython.display
404-
405-
IPython.display.clear_output(True)
287+
self._viewer.close()

0 commit comments

Comments
 (0)