# ProrokLab (https://www.proroklab.org/)
# All rights reserved.
import contextlib
-import functools
import math
import random
from ctypes import byref
@@ -47,22 +46,6 @@ def local_seed(vmas_random_state):
    random.setstate(py_state)


-def apply_local_seed(cls):
-    """Applies the local seed to all the functions."""
-    for attr_name, attr_value in cls.__dict__.items():
-        if callable(attr_value):
-            wrapped = attr_value  # Keep reference to original method
-
-            @functools.wraps(wrapped)
-            def wrapper(self, *args, _wrapped=wrapped, **kwargs):
-                with local_seed(cls.vmas_random_state):
-                    return _wrapped(self, *args, **kwargs)
-
-            setattr(cls, attr_name, wrapper)
-    return cls
-
-
-@apply_local_seed
class Environment(TorchVectorizedObject):
    metadata = {
        "render.modes": ["human", "rgb_array"],
@@ -74,6 +57,7 @@ class Environment(TorchVectorizedObject):
        random.getstate(),
    ]

+    @local_seed(vmas_random_state)
    def __init__(
        self,
        scenario: BaseScenario,
@@ -108,7 +92,7 @@ def __init__(
        self.grad_enabled = grad_enabled
        self.terminated_truncated = terminated_truncated

-        observations = self.reset(seed=seed)
+        observations = self._reset(seed=seed)

        # configure spaces
        self.multidiscrete_actions = multidiscrete_actions
@@ -121,6 +105,7 @@ def __init__(
        self.visible_display = None
        self.text_lines = None

+    @local_seed(vmas_random_state)
    def reset(
        self,
        seed: Optional[int] = None,
@@ -132,21 +117,112 @@ def reset(
        Resets the environment in a vectorized way
        Returns observations for all envs and agents
        """
+        return self._reset(
+            seed=seed,
+            return_observations=return_observations,
+            return_info=return_info,
+            return_dones=return_dones,
+        )
+
+    @local_seed(vmas_random_state)
+    def reset_at(
+        self,
+        index: int,
+        return_observations: bool = True,
+        return_info: bool = False,
+        return_dones: bool = False,
+    ):
+        """
+        Resets the environment at index
+        Returns observations for all agents in that environment
+        """
+        return self._reset_at(
+            index=index,
+            return_observations=return_observations,
+            return_info=return_info,
+            return_dones=return_dones,
+        )
+
+    @local_seed(vmas_random_state)
+    def get_from_scenario(
+        self,
+        get_observations: bool,
+        get_rewards: bool,
+        get_infos: bool,
+        get_dones: bool,
+        dict_agent_names: Optional[bool] = None,
+    ):
+        """
+        Get the environment data from the scenario
+
+        Args:
+            get_observations (bool): whether to return the observations
+            get_rewards (bool): whether to return the rewards
+            get_infos (bool): whether to return the infos
+            get_dones (bool): whether to return the dones
+            dict_agent_names (bool, optional): whether to return the information in a dictionary with agent names as keys
+                or in a list
+
+        Returns:
+            The agents' data
+
+        """
+        return self._get_from_scenario(
+            get_observations=get_observations,
+            get_rewards=get_rewards,
+            get_infos=get_infos,
+            get_dones=get_dones,
+            dict_agent_names=dict_agent_names,
+        )
+
+    @local_seed(vmas_random_state)
+    def seed(self, seed=None):
+        """
+        Sets the seed for the environment
+        Args:
+            seed (int, optional): Seed for the environment. Defaults to None.
+
+        """
+        return self._seed(seed=seed)
+
+    @local_seed(vmas_random_state)
+    def done(self):
+        """
+        Get the done flags for the scenario.
+
+        Returns:
+            Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)
+
+        """
+        return self._done()
+
+    def _reset(
+        self,
+        seed: Optional[int] = None,
+        return_observations: bool = True,
+        return_info: bool = False,
+        return_dones: bool = False,
+    ):
+        """
+        Resets the environment in a vectorized way
+        Returns observations for all envs and agents
+        """
+
        if seed is not None:
-            self.seed(seed)
+            self._seed(seed)
        # reset world
        self.scenario.env_reset_world_at(env_index=None)
        self.steps = torch.zeros(self.num_envs, device=self.device)

-        result = self.get_from_scenario(
+        result = self._get_from_scenario(
            get_observations=return_observations,
            get_infos=return_info,
            get_rewards=False,
            get_dones=return_dones,
        )
        return result[0] if result and len(result) == 1 else result

-    def reset_at(
+    def _reset_at(
        self,
        index: int,
        return_observations: bool = True,
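
The public entry points (reset, reset_at, get_from_scenario, seed, done, and step below) are now thin decorated wrappers, and the actual work lives in the undecorated underscore variants. Internal calls always target the underscore methods, so each public call swaps the RNG state exactly once rather than re-entering the swap on every nested call. A contrived sketch of the pattern with hypothetical names, reusing the local_seed sketch above:

    import random

    import numpy as np
    import torch


    class SeededThing:
        # Same layout as Environment.vmas_random_state: [torch, numpy, python] states.
        _state = [torch.random.get_rng_state(), np.random.get_state(), random.getstate()]

        @local_seed(_state)
        def sample(self):
            # Public boundary: the state swap happens here, once per call.
            return self._sample()

        def _sample(self):
            # Private implementation: may call other private helpers freely
            # without triggering another swap.
            return torch.rand(1), self._roll()

        def _roll(self):
            return random.random()
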
@@ -161,7 +237,7 @@ def reset_at(
        self.scenario.env_reset_world_at(index)
        self.steps[index] = 0

-        result = self.get_from_scenario(
+        result = self._get_from_scenario(
            get_observations=return_observations,
            get_infos=return_info,
            get_rewards=False,
@@ -170,7 +246,7 @@ def reset_at(

        return result[0] if result and len(result) == 1 else result

-    def get_from_scenario(
+    def _get_from_scenario(
        self,
        get_observations: bool,
        get_rewards: bool,
@@ -218,23 +294,30 @@ def get_from_scenario(

        if self.terminated_truncated:
            if get_dones:
-                terminated, truncated = self.done()
+                terminated, truncated = self._done()
            result = [obs, rewards, terminated, truncated, infos]
        else:
            if get_dones:
-                dones = self.done()
+                dones = self._done()
            result = [obs, rewards, dones, infos]

        return [data for data in result if data is not None]

-    def seed(self, seed=None):
+    def _seed(self, seed=None):
+        """
+        Sets the seed for the environment
+        Args:
+            seed (int, optional): Seed for the environment. Defaults to None.
+
+        """
        if seed is None:
            seed = 0
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        return [seed]

+    @local_seed(vmas_random_state)
    def step(self, actions: Union[List, Dict]):
        """Performs a vectorized step on all sub environments using `actions`.
        Args:
@@ -309,14 +392,21 @@ def step(self, actions: Union[List, Dict]):

        self.steps += 1

-        return self.get_from_scenario(
+        return self._get_from_scenario(
            get_observations=True,
            get_infos=True,
            get_rewards=True,
            get_dones=True,
        )

-    def done(self):
+    def _done(self):
+        """
+        Get the done flags for the scenario.
+
+        Returns:
+            Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)
+
+        """
        terminated = self.scenario.done().clone()

        if self.max_steps is not None:
@@ -427,6 +517,7 @@ def get_agent_observation_space(self, agent: Agent, obs: AGENT_OBS_TYPE):
                f"Invalid type of observation {obs} for agent {agent.name}"
            )

+    @local_seed(vmas_random_state)
    def get_random_action(self, agent: Agent) -> torch.Tensor:
        """Returns a random action for the given agent.

@@ -652,6 +743,7 @@ def _set_action(self, action, agent):
            )
        agent.action.c += noise

+    @local_seed(vmas_random_state)
    def render(
        self,
        mode="human",
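
Finally, a quick usage sketch of the guarantee these decorators are meant to provide: the environment keeps its own RNG stream in vmas_random_state, so calling its public methods should leave the caller's global generators untouched. The scenario name and the per-agent action size of 2 below are illustrative assumptions, not taken from this diff:

    import random

    import torch
    from vmas import make_env

    env = make_env(scenario="waterfall", num_envs=4, seed=0)

    torch_before = torch.random.get_rng_state()
    py_before = random.getstate()

    obs = env.reset()
    obs, rews, dones, info = env.step([torch.zeros(4, 2) for _ in env.agents])

    # The caller's global RNG state is expected to be unchanged.
    assert torch.equal(torch_before, torch.random.get_rng_state())
    assert py_before == random.getstate()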