Implementations of gym.Env and gym.vector.VectorEnv backed by a Schola Environment.
55"""
66
7- from typing import Dict , List , Tuple , TypeVar , Union
7+ from typing import Dict , List , Optional , Tuple , TypeVar , Union
88from schola .core .unreal_connections import UnrealConnection
9- from schola .core .env import ScholaEnv
9+ from schola .core .env import AutoResetType , ScholaEnv
1010from schola .core .error_manager import EnvironmentException
1111import numpy as np
1212import gymnasium as gym
1616
1717T = TypeVar ("T" )
1818
class GymEnv(gym.Env):
    """
    A single-agent, non-vectorized Gymnasium environment that wraps a Schola Environment.

    The wrapped ScholaEnv is created with auto-reset disabled, and every call is
    forwarded for environment id 0 / agent id 0 only.

    Parameters
    ----------
    unreal_connection : UnrealConnection
        The connection handed to the underlying ScholaEnv.
    verbosity : int, default=0
        Verbosity level forwarded to the underlying ScholaEnv.

    Raises
    ------
    AssertionError
        If the underlying environment does not expose exactly one id. The
        underlying environment is closed before the error propagates.
    """

    def __init__(self, unreal_connection: UnrealConnection, verbosity: int = 0):
        self._env = ScholaEnv(
            unreal_connection,
            verbosity=verbosity,
            auto_reset_type=AutoResetType.DISABLED,
        )

        # Everything after the ScholaEnv is created runs under one guard so a
        # failure at any point during setup releases the environment.
        try:
            self.id_manager = IdManager(self._env.ids)

            # Explicit raise rather than an `assert` statement: asserts are
            # stripped under `python -O`, which would silently drop this check.
            # AssertionError is kept so existing callers see the same type.
            if self.id_manager.num_ids != 1:
                raise AssertionError(
                    "GymEnv is designed for single-agent non-vectorized environments only. Please use GymVectorEnv for multi-agent or vectorized environments."
                )

            self.observation_space = self._env.get_obs_space(env_id=0, agent_id=0)
            self.action_space = self._env.get_action_space(env_id=0, agent_id=0)
        except Exception:
            self._env.close()
            raise

    def close(self) -> None:
        """
        Close the environment and release resources.
        """
        super().close()
        # Close the environment connection
        return self._env.close()

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[Dict[str, str]] = None,
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, str]]:
        """
        Reset the environment and return the initial observation and info.

        Parameters
        ----------
        seed : Optional[int]
            Seed passed to ``gym.Env.reset`` and forwarded to the underlying hard reset.
        options : Optional[Dict[str, str]]
            Options passed to ``gym.Env.reset`` and forwarded to the underlying hard reset.

        Returns
        -------
        Tuple[Dict[str, np.ndarray], Dict[str, str]]
            The observation and info entries for environment 0, agent 0.
        """
        super().reset(seed=seed, options=options)
        nested_obs, nested_infos = self._env.hard_reset(
            env_ids=[0], seeds=seed, options=options
        )
        # Results are nested by environment id, then agent id; unwrap the only pair.
        return nested_obs[0][0], nested_infos[0][0]

    def step(
        self, action: Dict[str, np.ndarray]
    ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, str]]:
        """
        Send one action to the environment and return the resulting transition.

        Parameters
        ----------
        action : Dict[str, np.ndarray]
            The action for the single agent.

        Returns
        -------
        Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, str]]
            Observation, reward, terminated flag, truncated flag and info for
            environment 0, agent 0.
        """
        # Send action for the first (and only) environment
        self._env.send_actions({0: {0: action}})
        observations, rewards, terminateds, truncateds, nested_infos = self._env.poll()
        return (
            observations[0][0],
            rewards[0][0],
            terminateds[0][0],
            truncateds[0][0],
            nested_infos[0][0],
        )
58+
59+
1960class GymVectorEnv (gym .vector .VectorEnv ):
2061 """
2162 A Gym Vector Environment that wraps a Schola Environment.
@@ -40,7 +81,9 @@ def __init__(
4081 self ._env = ScholaEnv (
4182 unreal_connection ,
4283 verbosity ,
84+ auto_reset_type = AutoResetType .SAME_STEP ,
4385 )
86+
4487 self .id_manager = IdManager (self ._env .ids )
4588 # we just use the default UID to get the shared definition
4689 single_obs_space = self ._env .get_obs_space (* self .id_manager [0 ])
0 commit comments