Skip to content

Commit 1556cd9

Browse files
authored
feat: upgrade gym wrapper to gymnasium (#264)
1 parent 7979c36 commit 1556cd9

File tree

7 files changed

+37
-37
lines changed

7 files changed

+37
-37
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ combinatorial problems.
8686
- 🍬 **Wrappers**: easily connect to your favourite RL frameworks and libraries such as
8787
[Acme](https://github.com/deepmind/acme),
8888
[Stable Baselines3](https://github.com/DLR-RM/stable-baselines3),
89-
[RLlib](https://docs.ray.io/en/latest/rllib/index.html), [OpenAI Gym](https://github.com/openai/gym)
89+
[RLlib](https://docs.ray.io/en/latest/rllib/index.html), [Gymnasium](https://github.com/Farama-Foundation/Gymnasium)
9090
and [DeepMind-Env](https://github.com/deepmind/dm_env) through our `dm_env` and `gym` wrappers.
9191
- 🎓 **Examples**: guides to facilitate Jumanji's adoption and highlight the added value of
9292
JAX-based environments.

docs/guides/wrappers.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,19 @@ next_timestep = dm_env.step(action)
1818
...
1919
```
2020

21-
## Jumanji To Gym
22-
We can also convert our Jumanji environments to a [Gym](https://github.com/openai/gym) environment!
23-
Below is an example of how to convert a Jumanji environment into a Gym environment.
21+
## Jumanji To Gymnasium
22+
We can also convert our Jumanji environments to a [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) environment!
23+
Below is an example of how to convert a Jumanji environment into a Gymnasium environment.
2424

2525
```python
2626
import jumanji.wrappers
2727

2828
env = jumanji.make("Snake-6x6-v0")
2929
gym_env = jumanji.wrappers.JumanjiToGymWrapper(env)
3030

31-
obs = gym_env.reset()
31+
obs, info = gym_env.reset()
3232
action = gym_env.action_space.sample()
33-
observation, reward, done, extra = gym_env.step(action)
33+
observation, reward, term, trunc, info = gym_env.step(action)
3434
...
3535
```
3636

jumanji/specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
import chex
3434
import dm_env.specs
35-
import gym
35+
import gymnasium as gym
3636
import jax
3737
import jax.numpy as jnp
3838
import numpy as np

jumanji/specs_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import chex
2222
import dm_env.specs
23-
import gym.spaces
23+
import gymnasium as gym
2424
import jax.numpy as jnp
2525
import numpy as np
2626
import pytest

jumanji/wrappers.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
from __future__ import annotations
1515

1616
from functools import cached_property
17-
from typing import Any, Callable, ClassVar, Dict, Generic, Optional, Tuple, Union
17+
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeAlias, Union
1818

1919
import chex
2020
import dm_env.specs
21-
import gym
21+
import gymnasium as gym
2222
import jax
2323
import jax.numpy as jnp
2424
import numpy as np
@@ -28,7 +28,7 @@
2828
from jumanji.types import TimeStep
2929

3030
# Type alias that corresponds to ObsType in the Gymnasium API
31-
GymObservation = Any
31+
GymObservation: TypeAlias = chex.ArrayNumpy | Dict[str, Union[chex.ArrayNumpy, "GymObservation"]]
3232

3333

3434
class Wrapper(Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]):
@@ -584,10 +584,6 @@ def render(self, state: State) -> Any:
584584
class JumanjiToGymWrapper(gym.Env, Generic[State, ActionSpec, Observation]):
585585
"""A wrapper that converts a Jumanji `Environment` to one that follows the `gym.Env` API."""
586586

587-
# Flag that prevents `gym.register` from misinterpreting the `_step` and
588-
# `_reset` as signs of a deprecated gym Env API.
589-
_gym_disable_underscore_compat: ClassVar[bool] = True
590-
591587
def __init__(
592588
self,
593589
env: Environment[State, ActionSpec, Observation],
@@ -618,21 +614,21 @@ def reset(key: chex.PRNGKey) -> Tuple[State, Observation, Optional[Dict]]:
618614

619615
def step(
    state: State, action: chex.Array
) -> Tuple[State, Observation, chex.Array, chex.Array, chex.Array, Optional[Any]]:
    """Step function of a Jumanji environment to be jitted.

    Args:
        state: current state of the wrapped Jumanji environment.
        action: action to apply, already converted to a JAX array.

    Returns:
        A tuple `(state, observation, reward, term, trunc, extras)` where
        `term` and `trunc` are boolean arrays that map to Gymnasium's
        `terminated` and `truncated` flags respectively.
    """
    state, timestep = self._env.step(state, action)
    # Following the dm_env convention, a zero discount marks a genuine
    # termination. Casting the discount to bool without negating it would
    # flag every ordinary step (discount == 1) as terminated and miss the
    # real terminal step, so the cast must be inverted.
    term = ~timestep.discount.astype(bool)
    # The end of the sequence (`timestep.last()`) is reported as truncation.
    trunc = timestep.last().astype(bool)
    return state, timestep.observation, timestep.reward, term, trunc, timestep.extras
626623

627624
self._step = jax.jit(step, backend=self.backend)
628625

629626
def reset(
630627
self,
631628
*,
632629
seed: Optional[int] = None,
633-
return_info: bool = False,
634630
options: Optional[dict] = None,
635-
) -> Union[GymObservation, Tuple[GymObservation, Optional[Any]]]:
631+
) -> Tuple[GymObservation, Dict[str, Any]]:
636632
"""Resets the environment to an initial state by starting a new sequence
637633
and returns the first `Observation` of this sequence.
638634
@@ -648,13 +644,11 @@ def reset(
648644
# Convert the observation to a numpy array or a nested dict thereof
649645
obs = jumanji_to_gym_obs(obs)
650646

651-
if return_info:
652-
info = jax.tree_util.tree_map(np.asarray, extras)
653-
return obs, info
654-
else:
655-
return obs # type: ignore
647+
return obs, jax.device_get(extras)
656648

657-
def step(self, action: chex.ArrayNumpy) -> Tuple[GymObservation, float, bool, Optional[Any]]:
649+
def step(
650+
self, action: chex.ArrayNumpy
651+
) -> Tuple[GymObservation, float, bool, bool, Dict[str, Any]]:
658652
"""Updates the environment according to the action and returns an `Observation`.
659653
660654
Args:
@@ -667,16 +661,17 @@ def step(self, action: chex.ArrayNumpy) -> Tuple[GymObservation, float, bool, Op
667661
info: contains supplementary information such as metrics.
668662
"""
669663

670-
action = jnp.array(action) # Convert input numpy array to JAX array
671-
self._state, obs, reward, done, extras = self._step(self._state, action)
664+
action_jax = jnp.asarray(action) # Convert input numpy array to JAX array
665+
self._state, obs, reward, term, trunc, extras = self._step(self._state, action_jax)
672666

673667
# Convert to get the correct signature
674668
obs = jumanji_to_gym_obs(obs)
675669
reward = float(reward)
676-
terminated = bool(done)
677-
info = jax.tree_util.tree_map(np.asarray, extras)
670+
terminated = bool(term)
671+
truncated = bool(trunc)
672+
info = jax.device_get(extras)
678673

679-
return obs, reward, terminated, info
674+
return obs, reward, terminated, truncated, info
680675

681676
def seed(self, seed: int = 0) -> None:
682677
"""Function which sets the seed for the environment's random number generator(s).

jumanji/wrappers_test.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import chex
1919
import dm_env.specs
20-
import gym
20+
import gymnasium as gym
2121
import jax
2222
import jax.numpy as jnp
2323
import numpy as np
@@ -265,15 +265,18 @@ def test_jumanji_environment_to_gym_env__reset(
265265
self, fake_gym_env: FakeJumanjiToGymWrapper
266266
) -> None:
267267
"""Validates reset function of the wrapped environment."""
268-
observation1 = fake_gym_env.reset()
268+
observation1, info1 = fake_gym_env.reset()
269269
state1 = fake_gym_env._state
270-
observation2 = fake_gym_env.reset()
270+
observation2, info2 = fake_gym_env.reset()
271271
state2 = fake_gym_env._state
272272

273273
# Observation is typically numpy array
274274
assert isinstance(observation1, chex.ArrayNumpy)
275275
assert isinstance(observation2, chex.ArrayNumpy)
276276

277+
assert isinstance(info1, dict)
278+
assert isinstance(info2, dict)
279+
277280
# Check that the observations are equal
278281
chex.assert_trees_all_equal(observation1, observation2)
279282
assert_trees_are_different(state1, state2)
@@ -282,12 +285,14 @@ def test_jumanji_environment_to_gym_env__step(
282285
self, fake_gym_env: FakeJumanjiToGymWrapper
283286
) -> None:
284287
"""Validates step function of the wrapped environment."""
285-
observation = fake_gym_env.reset()
288+
observation, _ = fake_gym_env.reset()
286289
action = fake_gym_env.action_space.sample()
287-
next_observation, reward, terminated, info = fake_gym_env.step(action)
290+
next_observation, reward, terminated, truncated, info = fake_gym_env.step(action)
288291
assert_trees_are_different(observation, next_observation)
289292
assert isinstance(reward, float)
290293
assert isinstance(terminated, bool)
294+
assert isinstance(truncated, bool)
295+
assert isinstance(info, dict)
291296

292297
def test_jumanji_environment_to_gym_env__observation_space(
293298
self, fake_gym_env: FakeJumanjiToGymWrapper

requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
chex>=0.1.3
22
dm-env>=1.5
3-
gym>=0.22.0
3+
gymnasium>=1.0
44
huggingface-hub
55
jax>=0.2.26
66
matplotlib~=3.7.4

0 commit comments

Comments
 (0)