 import logging
 import itertools
 import numpy as np
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Set
 
 import gym
 from gym import error, spaces
@@ -57,17 +57,26 @@ def __init__(
         :param no_graphics: Whether to run the Unity simulator in no-graphics mode
         :param allow_multiple_visual_obs: If True, return a list of visual observations instead of only one.
         """
+        base_port = 5005
+        if environment_filename is None:
+            base_port = UnityEnvironment.DEFAULT_EDITOR_PORT
+
         self._env = UnityEnvironment(
-            environment_filename, worker_id, no_graphics=no_graphics
+            environment_filename,
+            worker_id,
+            base_port=base_port,
+            no_graphics=no_graphics,
         )
 
         # Take a single step so that the brain information will be sent over
         if not self._env.get_agent_groups():
             self._env.step()
 
         self.visual_obs = None
-        self._current_state = None
-        self._n_agents = None
+        self._n_agents = -1
+        self._done_agents: Set[int] = set()
+        # Save the step result from the last time all Agents requested decisions.
+        self._previous_step_result: BatchedStepResult = None
         self._multiagent = multiagent
         self._flattener = None
         # Hidden flag used by Atari environments to determine if the game is over
@@ -111,6 +120,7 @@ def __init__(
         self._env.reset()
         step_result = self._env.get_step_result(self.brain_name)
         self._check_agents(step_result.n_agents())
+        self._previous_step_result = step_result
 
         # Set observation and action spaces
         if self.group_spec.is_action_discrete():
@@ -153,16 +163,15 @@ def reset(self) -> Union[List[np.ndarray], np.ndarray]:
         Returns: observation (object/list): the initial observation of the
         space.
         """
-        self._env.reset()
-        info = self._env.get_step_result(self.brain_name)
-        n_agents = info.n_agents()
+        step_result = self._step(True)
+        n_agents = step_result.n_agents()
         self._check_agents(n_agents)
         self.game_over = False
 
         if not self._multiagent:
-            res: GymStepResult = self._single_step(info)
+            res: GymStepResult = self._single_step(step_result)
         else:
-            res = self._multi_step(info)
+            res = self._multi_step(step_result)
         return res[0]
 
     def step(self, action: List[Any]) -> GymStepResult:
@@ -204,19 +213,20 @@ def step(self, action: List[Any]) -> GymStepResult:
 
         spec = self.group_spec
         action = np.array(action).reshape((self._n_agents, spec.action_size))
+        action = self._sanitize_action(action)
         self._env.set_actions(self.brain_name, action)
-        self._env.step()
-        info = self._env.get_step_result(self.brain_name)
-        n_agents = info.n_agents()
+
+        step_result = self._step()
+
+        n_agents = step_result.n_agents()
         self._check_agents(n_agents)
-        self._current_state = info
 
         if not self._multiagent:
-            single_res = self._single_step(info)
+            single_res = self._single_step(step_result)
             self.game_over = single_res[2]
             return single_res
         else:
-            multi_res = self._multi_step(info)
+            multi_res = self._multi_step(step_result)
             self.game_over = all(multi_res[2])
             return multi_res
 
@@ -233,8 +243,13 @@ def _single_step(self, info: BatchedStepResult) -> GymSingleStepResult:
                 self.visual_obs = self._preprocess_single(visual_obs[0][0])
 
             default_observation = self.visual_obs
-        else:
+        elif self._get_vec_obs_size() > 0:
             default_observation = self._get_vector_obs(info)[0, :]
+        else:
+            raise UnityGymException(
250+ "The Agent does not have vector observations and the environment was not setup"
251+ + "to use visual observations."
+            )
 
         return (
             default_observation,
@@ -335,7 +350,7 @@ def _check_agents(self, n_agents: int) -> None:
335350 "The environment was launched as a mutli-agent environment, however"
336351 "there is only one agent in the scene."
             )
-        if self._n_agents is None:
+        if self._n_agents == -1:
             self._n_agents = n_agents
             logger.info("{} agents within environment.".format(n_agents))
         elif self._n_agents != n_agents:
@@ -344,6 +359,84 @@ def _check_agents(self, n_agents: int) -> None:
344359 "initialization. This is not supported."
345360 )
346361
362+ def _sanitize_info (self , step_result : BatchedStepResult ) -> BatchedStepResult :
363+ n_extra_agents = step_result .n_agents () - self ._n_agents
364+ if n_extra_agents < 0 or n_extra_agents > self ._n_agents :
365+ # In this case, some Agents did not request a decision when expected
366+ # or too many requested a decision
367+ raise UnityGymException (
368+ "The number of agents in the scene does not match the expected number."
369+ )
370+
371+ # remove the done Agents
372+ indices_to_keep : List [int ] = []
373+ for index , is_done in enumerate (step_result .done ):
374+ if not is_done :
375+ indices_to_keep .append (index )
376+
377+ # Set the new AgentDone flags to True
378+ # Note that the corresponding agent_id that gets marked done will be different
379+ # than the original agent that was done, but this is OK since the gym interface
380+ # only cares about the ordering.
381+ for index , agent_id in enumerate (step_result .agent_id ):
382+ if not self ._previous_step_result .contains_agent (agent_id ):
383+ step_result .done [index ] = True
384+ if agent_id in self ._done_agents :
385+ step_result .done [index ] = True
386+ self ._done_agents = set ()
387+ self ._previous_step_result = step_result # store the new original
388+
389+ _mask : Optional [List [np .array ]] = None
390+ if step_result .action_mask is not None :
391+ _mask = []
392+ for mask_index in range (len (step_result .action_mask )):
393+ _mask .append (step_result .action_mask [mask_index ][indices_to_keep ])
394+ new_obs : List [np .array ] = []
395+ for obs_index in range (len (step_result .obs )):
396+ new_obs .append (step_result .obs [obs_index ][indices_to_keep ])
397+ return BatchedStepResult (
398+ obs = new_obs ,
399+ reward = step_result .reward [indices_to_keep ],
400+ done = step_result .done [indices_to_keep ],
401+ max_step = step_result .max_step [indices_to_keep ],
402+ agent_id = step_result .agent_id [indices_to_keep ],
403+ action_mask = _mask ,
404+ )
405+
406+ def _sanitize_action (self , action : np .array ) -> np .array :
407+ if self ._previous_step_result .n_agents () == self ._n_agents :
408+ return action
409+ sanitized_action = np .zeros (
410+ (self ._previous_step_result .n_agents (), self .group_spec .action_size )
411+ )
412+ input_index = 0
413+ for index in range (self ._previous_step_result .n_agents ()):
414+ if not self ._previous_step_result .done [index ]:
415+ sanitized_action [index , :] = action [input_index , :]
416+ input_index = input_index + 1
417+ return sanitized_action
418+
419+ def _step (self , needs_reset : bool = False ) -> BatchedStepResult :
420+ if needs_reset :
421+ self ._env .reset ()
422+ else :
423+ self ._env .step ()
424+ info = self ._env .get_step_result (self .brain_name )
425+ # Two possible cases here:
426+ # 1) all agents requested decisions (some of which might be done)
427+ # 2) some Agents were marked Done in between steps.
428+ # In case 2, we re-request decisions until all agents request a real decision.
429+ while info .n_agents () - sum (info .done ) < self ._n_agents :
430+ if not info .done .all ():
431+ raise UnityGymException (
432+ "The environment does not have the expected amount of agents."
433+ + "Some agents did not request decisions at the same time."
+                )
+            self._done_agents.update(list(info.agent_id))
+            self._env.step()
+            info = self._env.get_step_result(self.brain_name)
+        return self._sanitize_info(info)
+
     @property
     def metadata(self):
         return {"render.modes": ["rgb_array"]}
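
The _sanitize_action helper added above pads the policy's actions back out to the agent count of the previously stored step result, zero-filling the rows of agents that were done. The following standalone sketch (plain NumPy, outside the commit; pad_actions, previous_done, and the toy shapes are illustrative assumptions, not names from this code) mirrors that padding logic:

    import numpy as np

    def pad_actions(action: np.ndarray, previous_done: np.ndarray, action_size: int) -> np.ndarray:
        """Scatter live agents' actions into a full-size array; done agents get zeros.

        previous_done plays the role of self._previous_step_result.done in the commit.
        """
        sanitized = np.zeros((previous_done.shape[0], action_size))
        input_index = 0
        for index in range(previous_done.shape[0]):
            if not previous_done[index]:
                sanitized[index, :] = action[input_index, :]
                input_index += 1
        return sanitized

    # Toy example: 4 agents were present last step, agent 2 is done,
    # so the policy produced only 3 actions this step.
    previous_done = np.array([False, False, True, False])
    live_actions = np.array([[0.1], [0.2], [0.3]])
    print(pad_actions(live_actions, previous_done, action_size=1))
    # -> rows 0, 1, 3 hold the three actions; row 2 (the done agent) stays zero.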