Skip to content

Commit 0a28a16

Browse files
[Feature] Allow multiple observation keys 2 (#83)
1 parent 8b2197c commit 0a28a16

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

benchmarl/algorithms/qmix.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,19 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
177177

178178
def get_mixer(self, group: str) -> TensorDictModule:
179179
n_agents = len(self.group_map[group])
180-
group_observation_key = list(self.observation_spec[group].keys())[0]
181180

182181
if self.state_spec is not None:
183-
global_state_key = list(self.state_spec.keys())[0]
182+
global_state_key = list(self.state_spec.keys(True, True))[0]
184183
state_shape = self.state_spec[global_state_key].shape
185184
in_keys = [(group, "chosen_action_value"), global_state_key]
186185
else:
186+
group_observation_keys = list(self.observation_spec[group].keys(True, True))
187+
if len(group_observation_keys) > 1:
188+
raise ValueError(
189+
"QMIX called without a global state and multiple observation keys, currently the mixer "
190+
"takes only one observation key, please raise an issue if you need this feature."
191+
)
192+
group_observation_key = group_observation_keys[0]
187193
state_shape = self.observation_spec[group, group_observation_key].shape
188194
in_keys = [(group, "chosen_action_value"), (group, group_observation_key)]
189195

benchmarl/environments/common.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,32 @@ def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
156156
def observation_spec(self, env: EnvBase) -> CompositeSpec:
157157
"""
158158
A spec for the observation.
159-
Must be a CompositeSpec with one (group_name, observation_key) entry per group.
159+
Must be a CompositeSpec with as many entries as needed nested under the ``group_name`` key.
160160
161161
Args:
162162
env (EnvBase): An environment created via self.get_env_fun
163163
164+
Examples:
165+
>>> print(task.observation_spec(env))
166+
CompositeSpec(
167+
agents: CompositeSpec(
168+
observation: CompositeSpec(
169+
image: UnboundedDiscreteTensorSpec(
170+
shape=torch.Size([8, 88, 88, 3]),
171+
space=ContinuousBox(
172+
low=Tensor(shape=torch.Size([8, 88, 88, 3]), device=cpu, dtype=torch.int64, contiguous=True),
173+
high=Tensor(shape=torch.Size([8, 88, 88, 3]), device=cpu, dtype=torch.int64, contiguous=True)),
174+
device=cpu,
175+
dtype=torch.uint8,
176+
domain=discrete),
177+
array: UnboundedContinuousTensorSpec(
178+
shape=torch.Size([8, 3]),
179+
space=None,
180+
device=cpu,
181+
dtype=torch.float32,
182+
domain=continuous), device=cpu, shape=torch.Size([8])), device=cpu, shape=torch.Size([8])), device=cpu, shape=torch.Size([]))
183+
184+
164185
"""
165186
raise NotImplementedError
166187

0 commit comments

Comments
 (0)