Skip to content

Commit ab35c36

Browse files
authored
[BugFix] Fix agent_dim in multiagent nets & account for neg dims (#3290)
1 parent c43f212 commit ab35c36

File tree

3 files changed

+111
-11
lines changed

3 files changed

+111
-11
lines changed

test/test_modules.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
RSSMPrior,
5454
RSSMRollout,
5555
)
56+
from torchrl.modules.models.multiagent import MultiAgentNetBase
5657
from torchrl.modules.models.utils import SquashDims
5758
from torchrl.modules.planners.mppi import MPPIPlanner
5859
from torchrl.objectives.value import TDLambdaEstimator
@@ -1010,6 +1011,89 @@ def test_multiagent_reset_mlp(
10101011
.any()
10111012
)
10121013

1014+
@pytest.mark.parametrize("share_params", [True, False])
1015+
@pytest.mark.parametrize("agent_dim", [1, -3])
1016+
def test_multiagent_custom_agent_dim(self, share_params, agent_dim):
1017+
"""Test that custom agent_dim values work correctly.
1018+
1019+
Regression test for https://github.com/pytorch/rl/issues/3288
1020+
"""
1021+
n_agents = 3
1022+
obs_dim = 5
1023+
seq_len = 6
1024+
output_dim = 4
1025+
1026+
class SingleAgentMLP(nn.Module):
1027+
def __init__(self, in_dim, out_dim):
1028+
super().__init__()
1029+
self.net = nn.Sequential(
1030+
nn.Linear(in_dim, 32),
1031+
nn.Tanh(),
1032+
nn.Linear(32, out_dim),
1033+
)
1034+
1035+
def forward(self, x):
1036+
return self.net(x)
1037+
1038+
class MultiAgentPolicyNet(MultiAgentNetBase):
1039+
def __init__(
1040+
self,
1041+
obs_dim,
1042+
output_dim,
1043+
n_agents,
1044+
share_params,
1045+
agent_dim,
1046+
device=None,
1047+
):
1048+
self.obs_dim = obs_dim
1049+
self.output_dim = output_dim
1050+
self._agent_dim = agent_dim
1051+
1052+
super().__init__(
1053+
n_agents=n_agents,
1054+
centralized=False,
1055+
share_params=share_params,
1056+
agent_dim=agent_dim,
1057+
device=device,
1058+
)
1059+
1060+
def _build_single_net(self, *, device, **kwargs):
1061+
net = SingleAgentMLP(self.obs_dim, self.output_dim)
1062+
return net.to(device) if device is not None else net
1063+
1064+
def _pre_forward_check(self, inputs):
1065+
if inputs.shape[self._agent_dim] != self.n_agents:
1066+
raise ValueError(
1067+
f"Multi-agent network expected input with shape[{self._agent_dim}]={self.n_agents},"
1068+
f" but got {inputs.shape}"
1069+
)
1070+
return inputs
1071+
1072+
policy_net = MultiAgentPolicyNet(
1073+
obs_dim=obs_dim,
1074+
output_dim=output_dim,
1075+
n_agents=n_agents,
1076+
share_params=share_params,
1077+
agent_dim=agent_dim,
1078+
)
1079+
1080+
# Input shape: (batch, n_agents, seq_len, obs_dim) with agents at dim 1
1081+
batch_size = 4
1082+
obs = torch.randn(batch_size, n_agents, seq_len, obs_dim)
1083+
out = policy_net(obs)
1084+
1085+
# Output should preserve agent dimension position
1086+
expected_shape = (batch_size, n_agents, seq_len, output_dim)
1087+
assert (
1088+
out.shape == expected_shape
1089+
), f"Expected {expected_shape}, got {out.shape}"
1090+
1091+
# Verify different agents produce different outputs (unless share_params with same input)
1092+
if not share_params:
1093+
for i in range(n_agents):
1094+
for j in range(i + 1, n_agents):
1095+
assert not torch.allclose(out[:, i], out[:, j])
1096+
10131097
@pytest.mark.parametrize("n_agents", [1, 3])
10141098
@pytest.mark.parametrize("share_params", [True, False])
10151099
@pytest.mark.parametrize("centralized", [True, False])

torchrl/modules/llm/policies/transformers_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
790790

791791
if self._device is not None:
792792
response_struct = response_struct.to(self._device)
793-
793+
794794
tokens_prompt_padded = response_struct.get(
795795
"input_ids",
796796
as_padded_tensor=True,

torchrl/modules/models/multiagent.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,32 @@ def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
132132
else:
133133
inputs = inputs[0]
134134

135+
# Convert agent_dim to positive index for consistent output placement.
136+
# This ensures the agent dimension stays at the same position relative
137+
# to batch dimensions, even if the network changes the number of dimensions
138+
# (e.g., ConvNet collapses spatial dims).
139+
# NOTE: Must compute this BEFORE _pre_forward_check, which may modify input shape
140+
# (e.g., centralized mode flattens the agent dimension).
141+
agent_dim_positive = self.agent_dim
142+
if agent_dim_positive < 0:
143+
agent_dim_positive = inputs.ndim + agent_dim_positive
144+
135145
inputs = self._pre_forward_check(inputs)
146+
136147
# If parameters are not shared, each agent has its own network
137148
if not self.share_params:
138149
if self.centralized:
139150
output = self.vmap_func_module(
140-
self._empty_net, (0, None), (-2,), randomness=self.vmap_randomness
151+
self._empty_net,
152+
(0, None),
153+
(agent_dim_positive,),
154+
randomness=self.vmap_randomness,
141155
)(self.params, inputs)
142156
else:
143157
output = self.vmap_func_module(
144158
self._empty_net,
145-
(0, self.agent_dim),
146-
(-2,),
159+
(0, agent_dim_positive),
160+
(agent_dim_positive,),
147161
randomness=self.vmap_randomness,
148162
)(self.params, inputs)
149163

@@ -157,14 +171,16 @@ def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
157171
# We expand it to maintain the agent dimension, but values will be the same for all agents
158172
n_agent_outputs = output.shape[-1]
159173
output = output.view(*output.shape[:-1], n_agent_outputs)
160-
output = output.unsqueeze(-2)
161-
output = output.expand(
162-
*output.shape[:-2], self.n_agents, n_agent_outputs
163-
)
164-
165-
if output.shape[-2] != (self.n_agents):
174+
# Insert agent dimension at the correct position
175+
output = output.unsqueeze(agent_dim_positive)
176+
# Build the expanded shape
177+
expand_shape = list(output.shape)
178+
expand_shape[agent_dim_positive] = self.n_agents
179+
output = output.expand(*expand_shape)
180+
181+
if output.shape[agent_dim_positive] != (self.n_agents):
166182
raise ValueError(
167-
f"Multi-agent network expected output with shape[-2]={self.n_agents}"
183+
f"Multi-agent network expected output with shape[{agent_dim_positive}]={self.n_agents}"
168184
f" but got {output.shape}"
169185
)
170186

0 commit comments

Comments
 (0)