1414from __future__ import annotations
1515
1616from functools import cached_property
17- from typing import Any , Callable , ClassVar , Dict , Generic , Optional , Tuple , Union
17+ from typing import Any , Callable , Dict , Generic , Optional , Tuple , TypeAlias , Union
1818
1919import chex
2020import dm_env .specs
21- import gym
21+ import gymnasium as gym
2222import jax
2323import jax .numpy as jnp
2424import numpy as np
2828from jumanji .types import TimeStep
2929
# Observation type handed to Gym users (the counterpart of `ObsType` in the Gym API):
# either a `chex.ArrayNumpy`, or a string-keyed dict whose values are themselves
# `GymObservation`s (i.e. arbitrarily nested dicts of arrays).
GymObservation: TypeAlias = Union[
    chex.ArrayNumpy,
    Dict[str, Union[chex.ArrayNumpy, "GymObservation"]],
]
3232
3333
3434class Wrapper (Environment [State , ActionSpec , Observation ], Generic [State , ActionSpec , Observation ]):
@@ -584,10 +584,6 @@ def render(self, state: State) -> Any:
584584class JumanjiToGymWrapper (gym .Env , Generic [State , ActionSpec , Observation ]):
585585 """A wrapper that converts a Jumanji `Environment` to one that follows the `gym.Env` API."""
586586
587- # Flag that prevents `gym.register` from misinterpreting the `_step` and
588- # `_reset` as signs of a deprecated gym Env API.
589- _gym_disable_underscore_compat : ClassVar [bool ] = True
590-
591587 def __init__ (
592588 self ,
593589 env : Environment [State , ActionSpec , Observation ],
@@ -618,21 +614,21 @@ def reset(key: chex.PRNGKey) -> Tuple[State, Observation, Optional[Dict]]:
618614
def step(
    state: State, action: chex.Array
) -> Tuple[State, Observation, chex.Array, chex.Array, chex.Array, Optional[Any]]:
    """Step function of a Jumanji environment to be jitted.

    Args:
        state: current state of the wrapped environment.
        action: action to apply to the wrapped environment.

    Returns:
        A tuple `(state, observation, reward, term, trunc, extras)`, where `term` and
        `trunc` are boolean arrays feeding the Gym `terminated` / `truncated` flags.
    """
    state, timestep = self._env.step(state, action)
    # A terminal timestep carries a discount of 0 (dm_env convention, which Jumanji's
    # `TimeStep` is assumed to follow -- confirm against `jumanji.types`), so the
    # episode has terminated exactly when the discount is zero. Casting the discount
    # to bool without negating it would flag every ordinary step as terminated.
    term = ~timestep.discount.astype(bool)
    # NOTE(review): `last()` is also true on termination, so `trunc` is set whenever
    # the episode ends, not only on a time-limit cut-off.
    trunc = timestep.last().astype(bool)
    return state, timestep.observation, timestep.reward, term, trunc, timestep.extras
626623
627624 self ._step = jax .jit (step , backend = self .backend )
628625
629626 def reset (
630627 self ,
631628 * ,
632629 seed : Optional [int ] = None ,
633- return_info : bool = False ,
634630 options : Optional [dict ] = None ,
635- ) -> Union [GymObservation , Tuple [ GymObservation , Optional [ Any ] ]]:
631+ ) -> Tuple [GymObservation , Dict [ str , Any ]]:
636632 """Resets the environment to an initial state by starting a new sequence
637633 and returns the first `Observation` of this sequence.
638634
@@ -648,13 +644,11 @@ def reset(
648644 # Convert the observation to a numpy array or a nested dict thereof
649645 obs = jumanji_to_gym_obs (obs )
650646
651- if return_info :
652- info = jax .tree_util .tree_map (np .asarray , extras )
653- return obs , info
654- else :
655- return obs # type: ignore
647+ return obs , jax .device_get (extras )
656648
def step(
    self, action: chex.ArrayNumpy
) -> Tuple[GymObservation, float, bool, bool, Dict[str, Any]]:
    """Updates the environment according to the action and returns an `Observation`.

    Args:
        action: action to apply, given as a NumPy array.

    Returns:
        observation: the new observation, converted with `jumanji_to_gym_obs`.
        reward: the reward of this step as a Python float.
        terminated: the termination flag produced by the jitted step function.
        truncated: the truncation flag produced by the jitted step function.
        info: contains supplementary information such as metrics; an empty dict
            when the environment provides no extras.
    """
    action_jax = jnp.asarray(action)  # Convert input numpy array to JAX array
    self._state, obs, reward, term, trunc, extras = self._step(self._state, action_jax)

    # Convert to get the correct signature
    obs = jumanji_to_gym_obs(obs)
    # `extras` is optional; Gym consumers expect `info` to always be a dict, so
    # substitute an empty one rather than passing `None` through.
    info = {} if extras is None else jax.device_get(extras)

    return obs, float(reward), bool(term), bool(trunc), info
680675
681676 def seed (self , seed : int = 0 ) -> None :
682677 """Function which sets the seed for the environment's random number generator(s).
0 commit comments