1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import List , Optional , Sequence , Tuple
15+ from typing import Optional , Sequence , Tuple
1616
1717import chex
1818import jax
1919import jax .numpy as jnp
2020import matplotlib .animation
21- import matplotlib . pyplot as plt
21+ from numpy . typing import NDArray
2222
23- import jumanji .environments
2423from jumanji import specs
2524from jumanji .env import Environment
26- from jumanji .environments .logic .minesweeper .constants import (
27- COLOUR_MAPPING ,
28- PATCH_SIZE ,
29- UNEXPLORED_ID ,
30- )
25+ from jumanji .environments .logic .minesweeper .constants import PATCH_SIZE , UNEXPLORED_ID
3126from jumanji .environments .logic .minesweeper .done import DefaultDoneFn , DoneFn
27+ from jumanji .environments .logic .minesweeper .generator import (
28+ Generator ,
29+ UniformSamplingGenerator ,
30+ )
3231from jumanji .environments .logic .minesweeper .reward import DefaultRewardFn , RewardFn
3332from jumanji .environments .logic .minesweeper .types import Observation , State
34- from jumanji .environments .logic .minesweeper .utils import (
35- count_adjacent_mines ,
36- create_flat_mine_locations ,
37- explored_mine ,
38- )
33+ from jumanji .environments .logic .minesweeper .utils import count_adjacent_mines
34+ from jumanji .environments .logic .minesweeper .viewer import MinesweeperViewer
3935from jumanji .types import TimeStep , restart , termination , transition
36+ from jumanji .viewer import Viewer
4037
4138
4239class Minesweeper (Environment [State ]):
@@ -53,7 +50,7 @@ class Minesweeper(Environment[State]):
5350 specifies how many timesteps have elapsed since environment reset.
5451
5552 - action:
56- multi discrete array containing the square to explore (height and width ).
53+ multi discrete array containing the square to explore (row and col ).
5754
5855 - reward: jax array (float32):
5956 Configurable function of state and action. By default:
@@ -92,46 +89,47 @@ class Minesweeper(Environment[State]):
9289
9390 def __init__ (
9491 self ,
95- num_rows : int = 10 ,
96- num_cols : int = 10 ,
97- num_mines : int = 10 ,
92+ generator : Optional [Generator ] = None ,
9893 reward_function : Optional [RewardFn ] = None ,
9994 done_function : Optional [DoneFn ] = None ,
100- color_mapping : Optional [List [ str ]] = None ,
95+ viewer : Optional [Viewer [ State ]] = None ,
10196 ):
10297 """Instantiate a `Minesweeper` environment.
10398
10499 Args:
105- num_rows: number of rows, i.e. height of the board. Defaults to 10.
106- num_cols: number of columns, i.e. width of the board. Defaults to 10.
107- num_mines: number of mines on the board. Defaults to 10.
100+ generator: `Generator` to generate problem instances on environment reset.
101+ Implemented options are [`SamplingGenerator`]. Defaults to `SamplingGenerator`.
102+ The generator will have attributes:
103+ - num_rows: number of rows, i.e. height of the board. Defaults to 10.
104+ - num_cols: number of columns, i.e. width of the board. Defaults to 10.
105+ - num_mines: number of mines generated. Defaults to 10.
108106 reward_function: `RewardFn` whose `__call__` method computes the reward of an
109107 environment transition based on the given current state and selected action.
110- Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`.
108+ Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`, giving
109+ a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and
110+ 0.0 for an invalid action (selecting an already revealed square).
111111 done_function: `DoneFn` whose `__call__` method computes the done signal given the
112112 current state, action taken, and next state.
113- Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`.
114- color_mapping: colour map used for rendering.
113+ Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`, ending the
114+ episode on solving the board, revealing a mine, or picking an invalid action.
115+ viewer: `Viewer` to support rendering and animation methods.
116+ Implemented options are [`MinesweeperViewer`]. Defaults to `MinesweeperViewer`.
115117 """
116- if num_rows <= 1 or num_cols <= 1 :
117- raise ValueError (
118- f"Should make a board of height and width greater than 1, "
119- f"got num_rows={ num_rows } , num_cols={ num_cols } "
120- )
121- if num_mines < 0 or num_mines >= num_rows * num_cols :
122- raise ValueError (
123- f"Number of mines should be constrained between 0 and the size of the board, "
124- f"got { num_mines } "
125- )
126- self .num_rows = num_rows
127- self .num_cols = num_cols
128- self .num_mines = num_mines
129- self .reward_function = reward_function or DefaultRewardFn ()
118+ self .reward_function = reward_function or DefaultRewardFn (
119+ revealed_empty_square_reward = 1.0 ,
120+ revealed_mine_reward = 0.0 ,
121+ invalid_action_reward = 0.0 ,
122+ )
130123 self .done_function = done_function or DefaultDoneFn ()
131-
132- self .cmap = color_mapping if color_mapping else COLOUR_MAPPING
133- self .figure_name = f"{ num_rows } x{ num_cols } Minesweeper"
134- self .figure_size = (6.0 , 6.0 )
124+ self .generator = generator or UniformSamplingGenerator (
125+ num_rows = 10 , num_cols = 10 , num_mines = 10
126+ )
127+ self .num_rows = self .generator .num_rows
128+ self .num_cols = self .generator .num_cols
129+ self .num_mines = self .generator .num_mines
130+ self ._viewer = viewer or MinesweeperViewer (
131+ num_rows = self .num_rows , num_cols = self .num_cols
132+ )
135133
136134 def reset (self , key : chex .PRNGKey ) -> Tuple [State , TimeStep [Observation ]]:
137135 """Resets the environment.
@@ -144,25 +142,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
144142 timestep: `TimeStep` corresponding to the first timestep returned by the
145143 environment.
146144 """
147- key , sample_key = jax .random .split (key )
148- board = jnp .full (
149- shape = (self .num_rows , self .num_cols ),
150- fill_value = UNEXPLORED_ID ,
151- dtype = jnp .int32 ,
152- )
153- step_count = jnp .array (0 , jnp .int32 )
154- flat_mine_locations = create_flat_mine_locations (
155- key = sample_key ,
156- num_rows = self .num_rows ,
157- num_cols = self .num_cols ,
158- num_mines = self .num_mines ,
159- )
160- state = State (
161- board = board ,
162- step_count = step_count ,
163- key = key ,
164- flat_mine_locations = flat_mine_locations ,
165- )
145+ state = self .generator (key )
166146 observation = self ._state_to_observation (state = state )
167147 timestep = restart (observation = observation )
168148 return state , timestep
@@ -180,9 +160,7 @@ def step(
180160 next_state: `State` corresponding to the next state of the environment,
181161 next_timestep: `TimeStep` corresponding to the timestep returned by the environment.
182162 """
183- board = state .board
184- action_height , action_width = action
185- board = board .at [action_height , action_width ].set (
163+ board = state .board .at [tuple (action )].set (
186164 count_adjacent_mines (state = state , action = action )
187165 )
188166 step_count = state .step_count + 1
@@ -272,134 +250,38 @@ def _state_to_observation(self, state: State) -> Observation:
272250 step_count = state .step_count ,
273251 )
274252
275- def render (self , state : State ) -> None :
276- """Render the given environment state using matplotlib .
253+ def render (self , state : State ) -> Optional [ NDArray ] :
254+ """Renders the current state of the board .
277255
278256 Args:
279- state: environment state to be rendered.
280-
257+ state: the current state to be rendered.
281258 """
282- self ._clear_display ()
283- fig , ax = self ._get_fig_ax ()
284- self ._draw (ax , state )
285- self ._update_display (fig )
259+ return self ._viewer .render (state = state )
286260
287261 def animate (
288262 self ,
289263 states : Sequence [State ],
290264 interval : int = 200 ,
291265 save_path : Optional [str ] = None ,
292266 ) -> matplotlib .animation .FuncAnimation :
293- """Create an animation from a sequence of environment states.
267+ """Creates an animated gif of the board based on the sequence of states.
294268
295269 Args:
296- states: sequence of environment states corresponding to consecutive timesteps .
297- interval: delay between frames in milliseconds, default to 200.
270+ states: a list of `State` objects representing the sequence of states .
271+ interval: the delay between frames in milliseconds, default to 200.
298272 save_path: the path where the animation file should be saved. If it is None, the plot
299273 will not be saved.
300274
301275 Returns:
302- Animation object that can be saved as a GIF, MP4, or rendered with HTML .
276+ animation.FuncAnimation: the animation object that was created .
303277 """
304- fig , ax = self ._get_fig_ax ()
305- plt .tight_layout ()
306- plt .close (fig )
307-
308- def make_frame (state_index : int ) -> None :
309- state = states [state_index ]
310- self ._draw (ax , state )
311-
312- # Create the animation object.
313- self ._animation = matplotlib .animation .FuncAnimation (
314- fig ,
315- make_frame ,
316- frames = len (states ),
317- interval = interval ,
278+ return self ._viewer .animate (
279+ states = states , interval = interval , save_path = save_path
318280 )
319281
320- # Save the animation as a GIF.
321- if save_path :
322- self ._animation .save (save_path )
323-
324- return self ._animation
325-
326282 def close (self ) -> None :
327283 """Perform any necessary cleanup.
328-
329284 Environments will automatically :meth:`close()` themselves when
330285 garbage collected or when the program exits.
331286 """
332- plt .close (self .figure_name )
333-
334- def _get_fig_ax (self ) -> Tuple [plt .Figure , plt .Axes ]:
335- exists = plt .fignum_exists (self .figure_name )
336- if exists :
337- fig = plt .figure (self .figure_name )
338- ax = fig .get_axes ()[0 ]
339- else :
340- fig = plt .figure (self .figure_name , figsize = self .figure_size )
341- plt .suptitle (self .figure_name )
342- plt .tight_layout ()
343- if not plt .isinteractive ():
344- fig .show ()
345- ax = fig .add_subplot ()
346- return fig , ax
347-
348- def _draw (self , ax : plt .Axes , state : State ) -> None :
349- ax .clear ()
350- ax .set_xticks (jnp .arange (- 0.5 , self .num_cols - 1 , 1 ))
351- ax .set_yticks (jnp .arange (- 0.5 , self .num_rows - 1 , 1 ))
352- ax .tick_params (
353- top = False ,
354- bottom = False ,
355- left = False ,
356- right = False ,
357- labelleft = False ,
358- labelbottom = False ,
359- labeltop = False ,
360- labelright = False ,
361- )
362- background = jnp .ones_like (state .board )
363- for i in range (self .num_rows ):
364- for j in range (self .num_cols ):
365- background = self ._render_grid_square (
366- state = state , ax = ax , i = i , j = j , background = background
367- )
368- ax .imshow (background , cmap = "gray" , vmin = 0 , vmax = 1 )
369- ax .grid (color = "black" , linestyle = "-" , linewidth = 2 )
370-
371- def _render_grid_square (
372- self , state : State , ax : plt .Axes , i : int , j : int , background : chex .Array
373- ) -> chex .Array :
374- board_value = state .board [i , j ]
375- if board_value != UNEXPLORED_ID :
376- if explored_mine (state = state , action = jnp .array ([i , j ], dtype = jnp .int32 )):
377- background = background .at [i , j ].set (0 )
378- else :
379- ax .text (
380- j ,
381- i ,
382- str (board_value ),
383- color = self .cmap [board_value ],
384- ha = "center" ,
385- va = "center" ,
386- fontsize = "xx-large" ,
387- )
388- return background
389-
390- def _update_display (self , fig : plt .Figure ) -> None :
391- if plt .isinteractive ():
392- # Required to update render when using Jupyter Notebook.
393- fig .canvas .draw ()
394- if jumanji .environments .is_colab ():
395- plt .show (self .figure_name )
396- else :
397- # Required to update render when not using Jupyter Notebook.
398- fig .canvas .draw_idle ()
399- fig .canvas .flush_events ()
400-
401- def _clear_display (self ) -> None :
402- if jumanji .environments .is_colab ():
403- import IPython .display
404-
405- IPython .display .clear_output (True )
287+ self ._viewer .close ()
0 commit comments