@@ -216,31 +216,115 @@ class FlattenObservations(ConnectorV2):
216216 output_batch["obs"][(episode_2.id_,)][0][0],
217217 np.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]),
218218 )
219+
220+ # Multi-agent example: Use the connector with a multi-agent observation space.
221+ # The observation space must be a Dict with agent IDs as top-level keys.
222+ from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
223+
224+ # Define a per-agent observation space.
225+ per_agent_obs_space = gym.spaces.Dict({
226+ "a": gym.spaces.Box(-10.0, 10.0, (), np.float32),
227+ "b": gym.spaces.Tuple([
228+ gym.spaces.Discrete(2),
229+ gym.spaces.Box(-1.0, 1.0, (2, 1), np.float32),
230+ ]),
231+ "c": gym.spaces.MultiDiscrete([2, 3]),
232+ })
233+
234+ # Create a multi-agent observation space with agent IDs as keys.
235+ multi_agent_obs_space = gym.spaces.Dict({
236+ "agent_1": per_agent_obs_space,
237+ "agent_2": per_agent_obs_space,
238+ })
239+
240+ # Create a multi-agent episode with observations for both agents.
241+ # Agent IDs are inferred from the keys in the observations dict.
242+ ma_episode = MultiAgentEpisode(
243+ observations=[
244+ {
245+ "agent_1": {
246+ "a": np.array(-10.0, np.float32),
247+ "b": (1, np.array([[-1.0], [-1.0]], np.float32)),
248+ "c": np.array([0, 2]),
249+ },
250+ "agent_2": {
251+ "a": np.array(10.0, np.float32),
252+ "b": (0, np.array([[1.0], [1.0]], np.float32)),
253+ "c": np.array([1, 1]),
254+ },
255+ },
256+ ],
257+ )
258+
259+ # Construct the connector for multi-agent, flattening only agent_1's observations.
260+ # Note: If agent_ids is None (the default), all agents' observations are flattened.
261+ connector = FlattenObservations(
262+ multi_agent_obs_space,
263+ act_space,
264+ multi_agent=True,
265+ agent_ids=["agent_1"],
266+ )
267+
268+ # Call the connector.
269+ output_batch = connector(
270+ rl_module=None,
271+ batch={},
272+ episodes=[ma_episode],
273+ explore=True,
274+ shared_data={},
275+ )
276+
277+ # agent_1's observation is flattened.
278+ check(
279+ ma_episode.agent_episodes["agent_1"].get_observations(0),
280+ # box() disc(2). box(2, 1). multidisc(2, 3)........
281+ np.array([-10.0, 0.0, 1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0, 1.0]),
282+ )
283+
284+ # agent_2's observation is unchanged (not in agent_ids).
285+ check(
286+ ma_episode.agent_episodes["agent_2"].get_observations(0),
287+ {
288+ "a": np.array(10.0, np.float32),
289+ "b": (0, np.array([[1.0], [1.0]], np.float32)),
290+ "c": np.array([1, 1]),
291+ },
292+ )
219293 """
220294
221295 @override (ConnectorV2 )
222296 def recompute_output_observation_space (
223297 self ,
224- input_observation_space ,
225- input_action_space ,
298+ input_observation_space : gym . Space ,
299+ input_action_space : gym . Space ,
226300 ) -> gym .Space :
227301 self ._input_obs_base_struct = get_base_struct_from_space (
228302 self .input_observation_space
229303 )
230304
231305 if self ._multi_agent :
232306 spaces = {}
233- for agent_id , space in self ._input_obs_base_struct .items ():
307+ assert isinstance (
308+ input_observation_space , gym .spaces .Dict
309+ ), f"To flatten a Multi-Agent observation, it is expected that observation space is a dictionary, its actual type is { type (input_observation_space )} "
310+
311+ for agent_id , space in input_observation_space .items ():
234312 # Remove keys, if necessary.
235313 # TODO (simon): Maybe allow to remove different keys for different agents.
236314 if self ._keys_to_remove :
315+ assert isinstance (
316+ space , gym .spaces .Dict
317+ ), f"To remove keys from an observation space requires that it be a dictionary, its actual type is { type (space )} "
318+
237319 self ._input_obs_base_struct [agent_id ] = {
238320 k : v
239321 for k , v in self ._input_obs_base_struct [agent_id ].items ()
240322 if k not in self ._keys_to_remove
241323 }
324+
242325 if self ._agent_ids and agent_id not in self ._agent_ids :
243- spaces [agent_id ] = self ._input_obs_base_struct [agent_id ]
326+ # For nested spaces, we need to use the original Spaces (rather than the reduced version)
327+ spaces [agent_id ] = self .input_observation_space [agent_id ]
244328 else :
245329 sample = flatten_inputs_to_1d_tensor (
246330 tree .map_structure (
@@ -253,15 +337,20 @@ def recompute_output_observation_space(
253337 spaces [agent_id ] = Box (
254338 float ("-inf" ), float ("inf" ), (len (sample ),), np .float32
255339 )
340+
256341 return gym .spaces .Dict (spaces )
257342 else :
258343 # Remove keys, if necessary.
259344 if self ._keys_to_remove :
345+ assert isinstance (
346+ input_observation_space , gym .spaces .Dict
347+ ), f"To remove keys from an observation space requires that it be a dictionary, its actual type is { type (input_observation_space )} "
260348 self ._input_obs_base_struct = {
261349 k : v
262350 for k , v in self ._input_obs_base_struct .items ()
263351 if k not in self ._keys_to_remove
264352 }
353+
265354 sample = flatten_inputs_to_1d_tensor (
266355 tree .map_structure (
267356 lambda s : s .sample (),
@@ -286,13 +375,17 @@ def __init__(
286375 """Initializes a FlattenObservations instance.
287376
288377 Args:
378+ input_observation_space: The input observation space. For multi-agent
379+ setups, this must be a Dict space with agent IDs as top-level keys
380+ mapping to each agent's individual observation space.
381+ input_action_space: The input action space.
289382 multi_agent: Whether this connector operates on multi-agent observations,
290383 in which case, the top-level of the Dict space (where agent IDs are
291384 mapped to individual agents' observation spaces) is left as-is.
292385 agent_ids: If multi_agent is True, this argument defines a collection of
293- AgentIDs for which to flatten. AgentIDs not in this collection are
294- ignored .
295- If None, flatten observations for all AgentIDs. None is the default .
386+ AgentIDs for which to flatten. AgentIDs not in this collection will
387+ have their observations passed through unchanged.
388+ If None (the default), flatten observations for all AgentIDs.
296389 as_learner_connector: Whether this connector is part of a Learner connector
297390 pipeline, as opposed to an env-to-module pipeline.
298391 Note, this is usually only used for offline rl where the data comes
0 commit comments