Skip to content

Commit 76847d7

Browse files
pseudo-rnd-thoughtsMark TowersArturNiederfahrenhorst
authored and committed
[rllib] Update flatten_observations.py for nested spaces for ignored multi-agent (ray-project#59928)
## Description For the FlattenObservations connector, in the multi-agent case, if any of the agents had a nested observation space and that agent was ignored (i.e. not in `agent_ids`), we used the base_struct version of its observation space rather than the actual space, which is a problem for nested spaces. FYI, I believe that `get_base_struct_from_space` can be removed, as Gymnasium has added the ability to treat `Dict` and `Tuple` spaces like a dict or tuple for iterating and for calling `.keys()`, `.values()` and `.items()`. Possibly something for a future PR. ## Related issues Fixes ray-project#59849 --------- Signed-off-by: Mark Towers <mark@anyscale.com> Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com> Co-authored-by: Mark Towers <mark@anyscale.com> Co-authored-by: Artur Niederfahrenhorst <attaismyname@googlemail.com> Co-authored-by: Artur Niederfahrenhorst <artur@anyscale.com> Signed-off-by: jasonwrwang <jasonwrwang@tencent.com>
1 parent 538a29e commit 76847d7

File tree

1 file changed

+100
-7
lines changed

1 file changed

+100
-7
lines changed

rllib/connectors/common/flatten_observations.py

Lines changed: 100 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,31 +216,115 @@ class FlattenObservations(ConnectorV2):
216216
output_batch["obs"][(episode_2.id_,)][0][0],
217217
np.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]),
218218
)
219+
220+
# Multi-agent example: Use the connector with a multi-agent observation space.
221+
# The observation space must be a Dict with agent IDs as top-level keys.
222+
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
223+
224+
# Define a per-agent observation space.
225+
per_agent_obs_space = gym.spaces.Dict({
226+
"a": gym.spaces.Box(-10.0, 10.0, (), np.float32),
227+
"b": gym.spaces.Tuple([
228+
gym.spaces.Discrete(2),
229+
gym.spaces.Box(-1.0, 1.0, (2, 1), np.float32),
230+
]),
231+
"c": gym.spaces.MultiDiscrete([2, 3]),
232+
})
233+
234+
# Create a multi-agent observation space with agent IDs as keys.
235+
multi_agent_obs_space = gym.spaces.Dict({
236+
"agent_1": per_agent_obs_space,
237+
"agent_2": per_agent_obs_space,
238+
})
239+
240+
# Create a multi-agent episode with observations for both agents.
241+
# Agent IDs are inferred from the keys in the observations dict.
242+
ma_episode = MultiAgentEpisode(
243+
observations=[
244+
{
245+
"agent_1": {
246+
"a": np.array(-10.0, np.float32),
247+
"b": (1, np.array([[-1.0], [-1.0]], np.float32)),
248+
"c": np.array([0, 2]),
249+
},
250+
"agent_2": {
251+
"a": np.array(10.0, np.float32),
252+
"b": (0, np.array([[1.0], [1.0]], np.float32)),
253+
"c": np.array([1, 1]),
254+
},
255+
},
256+
],
257+
)
258+
259+
# Construct the connector for multi-agent, flattening only agent_1's observations.
260+
# Note: If agent_ids is None (the default), all agents' observations are flattened.
261+
connector = FlattenObservations(
262+
multi_agent_obs_space,
263+
act_space,
264+
multi_agent=True,
265+
agent_ids=["agent_1"],
266+
)
267+
268+
# Call the connector.
269+
output_batch = connector(
270+
rl_module=None,
271+
batch={},
272+
episodes=[ma_episode],
273+
explore=True,
274+
shared_data={},
275+
)
276+
277+
# agent_1's observation is flattened.
278+
check(
279+
ma_episode.agent_episodes["agent_1"].get_observations(0),
280+
# box() disc(2). box(2, 1). multidisc(2, 3)........
281+
np.array([-10.0, 0.0, 1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0, 1.0]),
282+
)
283+
284+
# agent_2's observation is unchanged (not in agent_ids).
285+
check(
286+
ma_episode.agent_episodes["agent_2"].get_observations(0),
287+
{
288+
"a": np.array(10.0, np.float32),
289+
"b": (0, np.array([[1.0], [1.0]], np.float32)),
290+
"c": np.array([1, 1]),
291+
},
292+
)
219293
"""
220294

221295
@override(ConnectorV2)
222296
def recompute_output_observation_space(
223297
self,
224-
input_observation_space,
225-
input_action_space,
298+
input_observation_space: gym.Space,
299+
input_action_space: gym.Space,
226300
) -> gym.Space:
227301
self._input_obs_base_struct = get_base_struct_from_space(
228302
self.input_observation_space
229303
)
230304

231305
if self._multi_agent:
232306
spaces = {}
233-
for agent_id, space in self._input_obs_base_struct.items():
307+
assert isinstance(
308+
input_observation_space, gym.spaces.Dict
309+
), f"To flatten a Multi-Agent observation, it is expected that observation space is a dictionary, its actual type is {type(input_observation_space)}"
310+
311+
for agent_id, space in input_observation_space.items():
234312
# Remove keys, if necessary.
235313
# TODO (simon): Maybe allow to remove different keys for different agents.
236314
if self._keys_to_remove:
315+
assert isinstance(
316+
space, gym.spaces.Dict
317+
), f"To remove keys from an observation space requires that it be a dictionary, its actual type is {type(space)}"
318+
237319
self._input_obs_base_struct[agent_id] = {
238320
k: v
239321
for k, v in self._input_obs_base_struct[agent_id].items()
240322
if k not in self._keys_to_remove
241323
}
324+
242325
if self._agent_ids and agent_id not in self._agent_ids:
243-
spaces[agent_id] = self._input_obs_base_struct[agent_id]
326+
# For nested spaces, we need to use the original Spaces (rather than the reduced version)
327+
spaces[agent_id] = self.input_observation_space[agent_id]
244328
else:
245329
sample = flatten_inputs_to_1d_tensor(
246330
tree.map_structure(
@@ -253,15 +337,20 @@ def recompute_output_observation_space(
253337
spaces[agent_id] = Box(
254338
float("-inf"), float("inf"), (len(sample),), np.float32
255339
)
340+
256341
return gym.spaces.Dict(spaces)
257342
else:
258343
# Remove keys, if necessary.
259344
if self._keys_to_remove:
345+
assert isinstance(
346+
input_observation_space, gym.spaces.Dict
347+
), f"To remove keys from an observation space requires that it be a dictionary, its actual type is {type(input_observation_space)}"
260348
self._input_obs_base_struct = {
261349
k: v
262350
for k, v in self._input_obs_base_struct.items()
263351
if k not in self._keys_to_remove
264352
}
353+
265354
sample = flatten_inputs_to_1d_tensor(
266355
tree.map_structure(
267356
lambda s: s.sample(),
@@ -286,13 +375,17 @@ def __init__(
286375
"""Initializes a FlattenObservations instance.
287376
288377
Args:
378+
input_observation_space: The input observation space. For multi-agent
379+
setups, this must be a Dict space with agent IDs as top-level keys
380+
mapping to each agent's individual observation space.
381+
input_action_space: The input action space.
289382
multi_agent: Whether this connector operates on multi-agent observations,
290383
in which case, the top-level of the Dict space (where agent IDs are
291384
mapped to individual agents' observation spaces) is left as-is.
292385
agent_ids: If multi_agent is True, this argument defines a collection of
293-
AgentIDs for which to flatten. AgentIDs not in this collection are
294-
ignored.
295-
If None, flatten observations for all AgentIDs. None is the default.
386+
AgentIDs for which to flatten. AgentIDs not in this collection will
387+
have their observations passed through unchanged.
388+
If None (the default), flatten observations for all AgentIDs.
296389
as_learner_connector: Whether this connector is part of a Learner connector
297390
pipeline, as opposed to an env-to-module pipeline.
298391
Note, this is usually only used for offline rl where the data comes

0 commit comments

Comments
 (0)