Skip to content

Commit 8b2197c

Browse files
[Feature] Allow multiple observation keys in specs (#82)
1 parent e6d46be commit 8b2197c

File tree

3 files changed

+34
-21
lines changed

3 files changed

+34
-21
lines changed

benchmarl/algorithms/common.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,6 @@ def _check_specs(self):
6767
"you can apply a transform to your environment to satisfy this criteria."
6868
)
6969
for group in self.group_map.keys():
70-
if len(self.observation_spec[group].keys(True, True)) != 1:
71-
raise ValueError(
72-
"Observation spec must contain one entry per group"
73-
" to follow the library conventions, "
74-
"you can apply a transform to your environment to satisfy this criteria."
75-
)
7670
if (
7771
len(self.action_spec[group].keys(True, True)) != 1
7872
or list(self.action_spec[group].keys())[0] != "action"

benchmarl/environments/meltingpot/common.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
from tensordict import TensorDictBase
1111

1212
from torchrl.data import CompositeSpec
13-
from torchrl.envs import DoubleToFloat, DTypeCastTransform, EnvBase, Transform
13+
from torchrl.envs import (
14+
DoubleToFloat,
15+
DTypeCastTransform,
16+
EnvBase,
17+
FlattenObservation,
18+
Transform,
19+
)
1420

1521
from benchmarl.environments.common import Task
1622
from benchmarl.utils import DEVICE_TYPING
@@ -81,6 +87,7 @@ def get_env_fun(
8187
return lambda: MeltingpotEnv(
8288
substrate=self.name.lower(),
8389
categorical_actions=True,
90+
device=device,
8491
**self.config,
8592
)
8693

@@ -100,7 +107,23 @@ def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
100107
return env.group_map
101108

102109
def get_env_transforms(self, env: EnvBase) -> List[Transform]:
103-
return [DoubleToFloat()]
110+
interaction_inventories_keys = [
111+
(group, "observation", "INTERACTION_INVENTORIES")
112+
for group in self.group_map(env).keys()
113+
if (group, "observation", "INTERACTION_INVENTORIES")
114+
in env.observation_spec.keys(True, True)
115+
]
116+
return [DoubleToFloat()] + (
117+
[
118+
FlattenObservation(
119+
in_keys=interaction_inventories_keys,
120+
first_dim=-2,
121+
last_dim=-1,
122+
)
123+
]
124+
if len(interaction_inventories_keys)
125+
else []
126+
)
104127

105128
def get_replay_buffer_transforms(self, env: EnvBase) -> List[Transform]:
106129
return [
@@ -141,11 +164,6 @@ def observation_spec(self, env: EnvBase) -> CompositeSpec:
141164
for group_key in list(observation_spec.keys()):
142165
if group_key not in self.group_map(env).keys():
143166
del observation_spec[group_key]
144-
else:
145-
group_obs_spec = observation_spec[group_key]["observation"]
146-
for key in list(group_obs_spec.keys()):
147-
if key != "RGB":
148-
del group_obs_spec[key]
149167
return observation_spec
150168

151169
def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:

benchmarl/experiment/experiment.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -387,14 +387,6 @@ def _setup_task(self):
387387
device=self.config.sampling_device,
388388
)
389389
)
390-
self.observation_spec = self.task.observation_spec(test_env)
391-
self.info_spec = self.task.info_spec(test_env)
392-
self.state_spec = self.task.state_spec(test_env)
393-
self.action_mask_spec = self.task.action_mask_spec(test_env)
394-
self.action_spec = self.task.action_spec(test_env)
395-
self.group_map = self.task.group_map(test_env)
396-
self.train_group_map = copy.deepcopy(self.group_map)
397-
self.max_steps = self.task.max_steps(test_env)
398390

399391
transforms_env = self.task.get_env_transforms(test_env)
400392
transforms_training = transforms_env + [
@@ -418,6 +410,15 @@ def _setup_task(self):
418410
self.config.sampling_device
419411
)
420412

413+
self.observation_spec = self.task.observation_spec(self.test_env)
414+
self.info_spec = self.task.info_spec(self.test_env)
415+
self.state_spec = self.task.state_spec(self.test_env)
416+
self.action_mask_spec = self.task.action_mask_spec(self.test_env)
417+
self.action_spec = self.task.action_spec(self.test_env)
418+
self.group_map = self.task.group_map(self.test_env)
419+
self.train_group_map = copy.deepcopy(self.group_map)
420+
self.max_steps = self.task.max_steps(self.test_env)
421+
421422
def _setup_algorithm(self):
422423
self.algorithm = self.algorithm_config.get_algorithm(experiment=self)
423424
self.replay_buffers = {

0 commit comments

Comments
 (0)