Commit a4acb65

Fixing Gym for single and multi-agent following reset changes on C# (#3417)
* """""Fixing"""""" * Update gym-unity/gym_unity/envs/__init__.py Co-Authored-By: Chris Elion <[email protected]> * Update gym-unity/gym_unity/envs/__init__.py Co-Authored-By: Chris Elion <[email protected]> * addressing comments * Update gym-unity/gym_unity/envs/__init__.py Co-Authored-By: Chris Elion <[email protected]> * Update gym-unity/gym_unity/envs/__init__.py Co-Authored-By: Chris Elion <[email protected]> * Update gym-unity/gym_unity/envs/__init__.py Co-Authored-By: Chris Elion <[email protected]> * bug fix * Fixing the test * gym multiagent comments (#3421) * rename and comments * enumerate Co-authored-by: Chris Elion <[email protected]>
1 parent 7abe4d4 commit a4acb65
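For context, a minimal usage sketch of the wrapper this commit touches (not part of the commit). It assumes the class is gym_unity.envs.UnityEnv with the constructor arguments visible in the diff below (environment_filename, worker_id, multiagent, no_graphics); "MyBuild" is a hypothetical environment binary path.

# Minimal sketch, assuming the gym_unity.envs.UnityEnv API shown in this diff.
# "MyBuild" is a hypothetical Unity build; not part of this commit.
from gym_unity.envs import UnityEnv

# Single-agent mode: reset() returns one observation, step() takes one action.
env = UnityEnv("MyBuild", worker_id=0, multiagent=False, no_graphics=True)
obs = env.reset()
for _ in range(10):
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
env.close()

# Multi-agent mode: reset() returns a list of observations (one per agent) and
# step() expects a list of actions in the same agent order.
multi_env = UnityEnv("MyBuild", worker_id=1, multiagent=True, no_graphics=True)
obs_list = multi_env.reset()
obs_list, rewards, dones, info = multi_env.step(
    [multi_env.action_space.sample() for _ in obs_list]
)
multi_env.close()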

File tree

2 files changed: +111 / -18 lines


gym-unity/gym_unity/envs/__init__.py

Lines changed: 110 additions & 17 deletions
@@ -1,7 +1,7 @@
 import logging
 import itertools
 import numpy as np
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Set

 import gym
 from gym import error, spaces
@@ -57,17 +57,26 @@ def __init__(
         :param no_graphics: Whether to run the Unity simulator in no-graphics mode
         :param allow_multiple_visual_obs: If True, return a list of visual observations instead of only one.
         """
+        base_port = 5005
+        if environment_filename is None:
+            base_port = UnityEnvironment.DEFAULT_EDITOR_PORT
+
         self._env = UnityEnvironment(
-            environment_filename, worker_id, no_graphics=no_graphics
+            environment_filename,
+            worker_id,
+            base_port=base_port,
+            no_graphics=no_graphics,
         )

         # Take a single step so that the brain information will be sent over
         if not self._env.get_agent_groups():
             self._env.step()

         self.visual_obs = None
-        self._current_state = None
-        self._n_agents = None
+        self._n_agents = -1
+        self._done_agents: Set[int] = set()
+        # Save the step result from the last time all Agents requested decisions.
+        self._previous_step_result: BatchedStepResult = None
         self._multiagent = multiagent
         self._flattener = None
         # Hidden flag used by Atari environments to determine if the game is over
@@ -111,6 +120,7 @@ def __init__(
         self._env.reset()
         step_result = self._env.get_step_result(self.brain_name)
         self._check_agents(step_result.n_agents())
+        self._previous_step_result = step_result

         # Set observation and action spaces
         if self.group_spec.is_action_discrete():
@@ -153,16 +163,15 @@ def reset(self) -> Union[List[np.ndarray], np.ndarray]:
         Returns: observation (object/list): the initial observation of the
         space.
         """
-        self._env.reset()
-        info = self._env.get_step_result(self.brain_name)
-        n_agents = info.n_agents()
+        step_result = self._step(True)
+        n_agents = step_result.n_agents()
         self._check_agents(n_agents)
         self.game_over = False

         if not self._multiagent:
-            res: GymStepResult = self._single_step(info)
+            res: GymStepResult = self._single_step(step_result)
         else:
-            res = self._multi_step(info)
+            res = self._multi_step(step_result)
         return res[0]

     def step(self, action: List[Any]) -> GymStepResult:
@@ -204,19 +213,20 @@ def step(self, action: List[Any]) -> GymStepResult:

         spec = self.group_spec
         action = np.array(action).reshape((self._n_agents, spec.action_size))
+        action = self._sanitize_action(action)
         self._env.set_actions(self.brain_name, action)
-        self._env.step()
-        info = self._env.get_step_result(self.brain_name)
-        n_agents = info.n_agents()
+
+        step_result = self._step()
+
+        n_agents = step_result.n_agents()
         self._check_agents(n_agents)
-        self._current_state = info

         if not self._multiagent:
-            single_res = self._single_step(info)
+            single_res = self._single_step(step_result)
             self.game_over = single_res[2]
             return single_res
         else:
-            multi_res = self._multi_step(info)
+            multi_res = self._multi_step(step_result)
             self.game_over = all(multi_res[2])
             return multi_res


@@ -233,8 +243,13 @@ def _single_step(self, info: BatchedStepResult) -> GymSingleStepResult:
                 self.visual_obs = self._preprocess_single(visual_obs[0][0])

             default_observation = self.visual_obs
-        else:
+        elif self._get_vec_obs_size() > 0:
             default_observation = self._get_vector_obs(info)[0, :]
+        else:
+            raise UnityGymException(
+                "The Agent does not have vector observations and the environment was not setup"
+                + "to use visual observations."
+            )

         return (
             default_observation,
@@ -335,7 +350,7 @@ def _check_agents(self, n_agents: int) -> None:
                 "The environment was launched as a mutli-agent environment, however"
                 "there is only one agent in the scene."
             )
-        if self._n_agents is None:
+        if self._n_agents == -1:
             self._n_agents = n_agents
             logger.info("{} agents within environment.".format(n_agents))
         elif self._n_agents != n_agents:
@@ -344,6 +359,84 @@ def _check_agents(self, n_agents: int) -> None:
                 "initialization. This is not supported."
             )

+    def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
+        n_extra_agents = step_result.n_agents() - self._n_agents
+        if n_extra_agents < 0 or n_extra_agents > self._n_agents:
+            # In this case, some Agents did not request a decision when expected
+            # or too many requested a decision
+            raise UnityGymException(
+                "The number of agents in the scene does not match the expected number."
+            )
+
+        # remove the done Agents
+        indices_to_keep: List[int] = []
+        for index, is_done in enumerate(step_result.done):
+            if not is_done:
+                indices_to_keep.append(index)
+
+        # Set the new AgentDone flags to True
+        # Note that the corresponding agent_id that gets marked done will be different
+        # than the original agent that was done, but this is OK since the gym interface
+        # only cares about the ordering.
+        for index, agent_id in enumerate(step_result.agent_id):
+            if not self._previous_step_result.contains_agent(agent_id):
+                step_result.done[index] = True
+            if agent_id in self._done_agents:
+                step_result.done[index] = True
+        self._done_agents = set()
+        self._previous_step_result = step_result  # store the new original
+
+        _mask: Optional[List[np.array]] = None
+        if step_result.action_mask is not None:
+            _mask = []
+            for mask_index in range(len(step_result.action_mask)):
+                _mask.append(step_result.action_mask[mask_index][indices_to_keep])
+        new_obs: List[np.array] = []
+        for obs_index in range(len(step_result.obs)):
+            new_obs.append(step_result.obs[obs_index][indices_to_keep])
+        return BatchedStepResult(
+            obs=new_obs,
+            reward=step_result.reward[indices_to_keep],
+            done=step_result.done[indices_to_keep],
+            max_step=step_result.max_step[indices_to_keep],
+            agent_id=step_result.agent_id[indices_to_keep],
+            action_mask=_mask,
+        )
+
+    def _sanitize_action(self, action: np.array) -> np.array:
+        if self._previous_step_result.n_agents() == self._n_agents:
+            return action
+        sanitized_action = np.zeros(
+            (self._previous_step_result.n_agents(), self.group_spec.action_size)
+        )
+        input_index = 0
+        for index in range(self._previous_step_result.n_agents()):
+            if not self._previous_step_result.done[index]:
+                sanitized_action[index, :] = action[input_index, :]
+                input_index = input_index + 1
+        return sanitized_action
+
+    def _step(self, needs_reset: bool = False) -> BatchedStepResult:
+        if needs_reset:
+            self._env.reset()
+        else:
+            self._env.step()
+        info = self._env.get_step_result(self.brain_name)
+        # Two possible cases here:
+        # 1) all agents requested decisions (some of which might be done)
+        # 2) some Agents were marked Done in between steps.
+        # In case 2, we re-request decisions until all agents request a real decision.
+        while info.n_agents() - sum(info.done) < self._n_agents:
+            if not info.done.all():
+                raise UnityGymException(
+                    "The environment does not have the expected amount of agents."
+                    + "Some agents did not request decisions at the same time."
+                )
+            self._done_agents.update(list(info.agent_id))
+            self._env.step()
+            info = self._env.get_step_result(self.brain_name)
+        return self._sanitize_info(info)
+
     @property
     def metadata(self):
         return {"render.modes": ["rgb_array"]}

gym-unity/gym_unity/tests/test_gym.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def create_mock_vector_step_result(num_agents=1, number_visual_observations=0):

     :int num_agents: Number of "agents" to imitate in your BatchedStepResult values.
     """
-    obs = [np.array([num_agents * [1, 2, 3]])]
+    obs = [np.array([num_agents * [1, 2, 3]]).reshape(num_agents, 3)]
     if number_visual_observations:
         obs += [np.zeros(shape=(num_agents, 8, 8, 3), dtype=np.float32)]
     rewards = np.array(num_agents * [1.0])
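Why this one-line test change matters: np.array([num_agents * [1, 2, 3]]) builds a single row of length 3 * num_agents (shape (1, 3 * num_agents)), while the mock BatchedStepResult presumably needs one 3-element observation row per agent so that the new per-agent indexing in _sanitize_info works. The added .reshape(num_agents, 3) restores that layout. A quick check in plain NumPy (not part of the test suite):

import numpy as np

num_agents = 2
flat = np.array([num_agents * [1, 2, 3]])
print(flat.shape)                          # (1, 6) -> one row for all agents
print(flat.reshape(num_agents, 3).shape)   # (2, 3) -> one row per agent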
