Skip to content

Commit 3a9a40c

Browse files
matteobettiniezhang7423
andauthored
[Feature] Allow multipe inputs to models (#73)
* add multiagent cnn implementation and tests * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * docs * mend * mend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend * amend --------- Co-authored-by: ezhang7423 <[email protected]>
1 parent e272278 commit 3a9a40c

File tree

10 files changed

+208
-158
lines changed

10 files changed

+208
-158
lines changed

benchmarl/algorithms/iddpg.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dataclasses import dataclass, MISSING
88
from typing import Dict, Iterable, Tuple, Type
99

10-
import torch
1110
from tensordict import TensorDictBase
1211
from tensordict.nn import TensorDictModule, TensorDictSequential
1312
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
@@ -188,34 +187,12 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
188187
def get_value_module(self, group: str) -> TensorDictModule:
189188
n_agents = len(self.group_map[group])
190189
modules = []
191-
group_observation_key = list(self.observation_spec[group].keys())[0]
192190

193-
modules.append(
194-
TensorDictModule(
195-
lambda obs, action: torch.cat([obs, action], dim=-1),
196-
in_keys=[
197-
(group, group_observation_key),
198-
(group, "action"),
199-
],
200-
out_keys=[(group, "obs_action")],
201-
)
202-
)
203191
critic_input_spec = CompositeSpec(
204192
{
205-
group: CompositeSpec(
206-
{
207-
"obs_action": UnboundedContinuousTensorSpec(
208-
shape=(
209-
n_agents,
210-
self.observation_spec[
211-
group, group_observation_key
212-
].shape[-1]
213-
+ self.action_spec[group, "action"].shape[-1],
214-
)
215-
)
216-
},
217-
shape=(n_agents,),
218-
)
193+
group: self.observation_spec[group]
194+
.clone()
195+
.update(self.action_spec[group])
219196
}
220197
)
221198
critic_output_spec = CompositeSpec(

benchmarl/algorithms/isac.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dataclasses import dataclass, MISSING
88
from typing import Dict, Iterable, Optional, Tuple, Type, Union
99

10-
import torch
1110
from tensordict import TensorDictBase
1211
from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictSequential
1312
from torch.distributions import Categorical
@@ -315,31 +314,12 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
315314
def get_continuous_value_module(self, group: str) -> TensorDictModule:
316315
n_agents = len(self.group_map[group])
317316
modules = []
318-
group_observation_key = list(self.observation_spec[group].keys())[0]
319317

320-
modules.append(
321-
TensorDictModule(
322-
lambda obs, action: torch.cat([obs, action], dim=-1),
323-
in_keys=[(group, group_observation_key), (group, "action")],
324-
out_keys=[(group, "obs_action")],
325-
)
326-
)
327318
critic_input_spec = CompositeSpec(
328319
{
329-
group: CompositeSpec(
330-
{
331-
"obs_action": UnboundedContinuousTensorSpec(
332-
shape=(
333-
n_agents,
334-
self.observation_spec[
335-
group, group_observation_key
336-
].shape[-1]
337-
+ self.action_spec[group, "action"].shape[-1],
338-
)
339-
)
340-
},
341-
shape=(n_agents,),
342-
)
320+
group: self.observation_spec[group]
321+
.clone()
322+
.update(self.action_spec[group])
343323
}
344324
)
345325

benchmarl/algorithms/maddpg.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dataclasses import dataclass, MISSING
88
from typing import Dict, Iterable, Tuple, Type
99

10-
import torch
1110
from tensordict import TensorDictBase
1211
from tensordict.nn import TensorDictModule, TensorDictSequential
1312
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
@@ -41,7 +40,7 @@ def __init__(
4140
loss_function: str,
4241
delay_value: bool,
4342
use_tanh_mapping: bool,
44-
**kwargs
43+
**kwargs,
4544
):
4645
super().__init__(**kwargs)
4746

@@ -188,7 +187,6 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
188187
def get_value_module(self, group: str) -> TensorDictModule:
189188
n_agents = len(self.group_map[group])
190189
modules = []
191-
group_observation_key = list(self.observation_spec[group].keys())[0]
192190

193191
if self.share_param_critic:
194192
critic_output_spec = CompositeSpec(
@@ -209,23 +207,18 @@ def get_value_module(self, group: str) -> TensorDictModule:
209207
)
210208

211209
if self.state_spec is not None:
212-
global_state_key = list(self.state_spec.keys())[0]
213210
modules.append(
214211
TensorDictModule(
215-
lambda state, action: torch.cat(
216-
[state, action.reshape(*action.shape[:-2], -1)], dim=-1
217-
),
218-
in_keys=[global_state_key, (group, "action")],
219-
out_keys=["state_action"],
212+
lambda action: action.reshape(*action.shape[:-2], -1),
213+
in_keys=[(group, "action")],
214+
out_keys=["global_action"],
220215
)
221216
)
222-
critic_input_spec = CompositeSpec(
217+
218+
critic_input_spec = self.state_spec.clone().update(
223219
{
224-
"state_action": UnboundedContinuousTensorSpec(
225-
shape=(
226-
self.state_spec[global_state_key].shape[-1]
227-
+ self.action_spec[group, "action"].shape[-1] * n_agents,
228-
)
220+
"global_action": UnboundedContinuousTensorSpec(
221+
shape=(self.action_spec[group, "action"].shape[-1] * n_agents,)
229222
)
230223
}
231224
)
@@ -245,29 +238,11 @@ def get_value_module(self, group: str) -> TensorDictModule:
245238
)
246239

247240
else:
248-
modules.append(
249-
TensorDictModule(
250-
lambda obs, action: torch.cat([obs, action], dim=-1),
251-
in_keys=[(group, group_observation_key), (group, "action")],
252-
out_keys=[(group, "obs_action")],
253-
)
254-
)
255241
critic_input_spec = CompositeSpec(
256242
{
257-
group: CompositeSpec(
258-
{
259-
"obs_action": UnboundedContinuousTensorSpec(
260-
shape=(
261-
n_agents,
262-
self.observation_spec[
263-
group, group_observation_key
264-
].shape[-1]
265-
+ self.action_spec[group, "action"].shape[-1],
266-
)
267-
)
268-
},
269-
shape=(n_agents,),
270-
)
243+
group: self.observation_spec[group]
244+
.clone()
245+
.update(self.action_spec[group])
271246
}
272247
)
273248

benchmarl/algorithms/masac.py

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dataclasses import dataclass, MISSING
88
from typing import Dict, Iterable, Optional, Tuple, Type, Union
99

10-
import torch
1110
from tensordict import TensorDictBase
1211
from tensordict.nn import NormalParamExtractor, TensorDictModule, TensorDictSequential
1312
from torch.distributions import Categorical
@@ -342,7 +341,6 @@ def get_discrete_value_module(self, group: str) -> TensorDictModule:
342341
def get_continuous_value_module(self, group: str) -> TensorDictModule:
343342
n_agents = len(self.group_map[group])
344343
modules = []
345-
group_observation_key = list(self.observation_spec[group].keys())[0]
346344

347345
if self.share_param_critic:
348346
critic_output_spec = CompositeSpec(
@@ -363,23 +361,19 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
363361
)
364362

365363
if self.state_spec is not None:
366-
global_state_key = list(self.state_spec.keys())[0]
364+
367365
modules.append(
368366
TensorDictModule(
369-
lambda state, action: torch.cat(
370-
[state, action.reshape(*action.shape[:-2], -1)], dim=-1
371-
),
372-
in_keys=[global_state_key, (group, "action")],
373-
out_keys=["state_action"],
367+
lambda action: action.reshape(*action.shape[:-2], -1),
368+
in_keys=[(group, "action")],
369+
out_keys=["global_action"],
374370
)
375371
)
376-
critic_input_spec = CompositeSpec(
372+
373+
critic_input_spec = self.state_spec.clone().update(
377374
{
378-
"state_action": UnboundedContinuousTensorSpec(
379-
shape=(
380-
self.state_spec[global_state_key].shape[-1]
381-
+ self.action_spec[group, "action"].shape[-1] * n_agents,
382-
)
375+
"global_action": UnboundedContinuousTensorSpec(
376+
shape=(self.action_spec[group, "action"].shape[-1] * n_agents,)
383377
)
384378
}
385379
)
@@ -399,29 +393,11 @@ def get_continuous_value_module(self, group: str) -> TensorDictModule:
399393
)
400394

401395
else:
402-
modules.append(
403-
TensorDictModule(
404-
lambda obs, action: torch.cat([obs, action], dim=-1),
405-
in_keys=[(group, group_observation_key), (group, "action")],
406-
out_keys=[(group, "obs_action")],
407-
)
408-
)
409396
critic_input_spec = CompositeSpec(
410397
{
411-
group: CompositeSpec(
412-
{
413-
"obs_action": UnboundedContinuousTensorSpec(
414-
shape=(
415-
n_agents,
416-
self.observation_spec[
417-
group, group_observation_key
418-
].shape[-1]
419-
+ self.action_spec[group, "action"].shape[-1],
420-
)
421-
)
422-
},
423-
shape=(n_agents,),
424-
)
398+
group: self.observation_spec[group]
399+
.clone()
400+
.update(self.action_spec[group])
425401
}
426402
)
427403

0 commit comments

Comments
 (0)