|
53 | 53 | RSSMPrior, |
54 | 54 | RSSMRollout, |
55 | 55 | ) |
| 56 | +from torchrl.modules.models.multiagent import MultiAgentNetBase |
56 | 57 | from torchrl.modules.models.utils import SquashDims |
57 | 58 | from torchrl.modules.planners.mppi import MPPIPlanner |
58 | 59 | from torchrl.objectives.value import TDLambdaEstimator |
@@ -1010,6 +1011,89 @@ def test_multiagent_reset_mlp( |
1010 | 1011 | .any() |
1011 | 1012 | ) |
1012 | 1013 |
|
| 1014 | + @pytest.mark.parametrize("share_params", [True, False]) |
| 1015 | + @pytest.mark.parametrize("agent_dim", [1, -3]) |
| 1016 | + def test_multiagent_custom_agent_dim(self, share_params, agent_dim): |
| 1017 | + """Test that custom agent_dim values work correctly. |
| 1018 | +
|
| 1019 | + Regression test for https://github.com/pytorch/rl/issues/3288 |
| 1020 | + """ |
| 1021 | + n_agents = 3 |
| 1022 | + obs_dim = 5 |
| 1023 | + seq_len = 6 |
| 1024 | + output_dim = 4 |
| 1025 | + |
| 1026 | + class SingleAgentMLP(nn.Module): |
| 1027 | + def __init__(self, in_dim, out_dim): |
| 1028 | + super().__init__() |
| 1029 | + self.net = nn.Sequential( |
| 1030 | + nn.Linear(in_dim, 32), |
| 1031 | + nn.Tanh(), |
| 1032 | + nn.Linear(32, out_dim), |
| 1033 | + ) |
| 1034 | + |
| 1035 | + def forward(self, x): |
| 1036 | + return self.net(x) |
| 1037 | + |
| 1038 | + class MultiAgentPolicyNet(MultiAgentNetBase): |
| 1039 | + def __init__( |
| 1040 | + self, |
| 1041 | + obs_dim, |
| 1042 | + output_dim, |
| 1043 | + n_agents, |
| 1044 | + share_params, |
| 1045 | + agent_dim, |
| 1046 | + device=None, |
| 1047 | + ): |
| 1048 | + self.obs_dim = obs_dim |
| 1049 | + self.output_dim = output_dim |
| 1050 | + self._agent_dim = agent_dim |
| 1051 | + |
| 1052 | + super().__init__( |
| 1053 | + n_agents=n_agents, |
| 1054 | + centralized=False, |
| 1055 | + share_params=share_params, |
| 1056 | + agent_dim=agent_dim, |
| 1057 | + device=device, |
| 1058 | + ) |
| 1059 | + |
| 1060 | + def _build_single_net(self, *, device, **kwargs): |
| 1061 | + net = SingleAgentMLP(self.obs_dim, self.output_dim) |
| 1062 | + return net.to(device) if device is not None else net |
| 1063 | + |
| 1064 | + def _pre_forward_check(self, inputs): |
| 1065 | + if inputs.shape[self._agent_dim] != self.n_agents: |
| 1066 | + raise ValueError( |
| 1067 | + f"Multi-agent network expected input with shape[{self._agent_dim}]={self.n_agents}," |
| 1068 | + f" but got {inputs.shape}" |
| 1069 | + ) |
| 1070 | + return inputs |
| 1071 | + |
| 1072 | + policy_net = MultiAgentPolicyNet( |
| 1073 | + obs_dim=obs_dim, |
| 1074 | + output_dim=output_dim, |
| 1075 | + n_agents=n_agents, |
| 1076 | + share_params=share_params, |
| 1077 | + agent_dim=agent_dim, |
| 1078 | + ) |
| 1079 | + |
| 1080 | + # Input shape: (batch, n_agents, seq_len, obs_dim) with agents at dim 1 |
| 1081 | + batch_size = 4 |
| 1082 | + obs = torch.randn(batch_size, n_agents, seq_len, obs_dim) |
| 1083 | + out = policy_net(obs) |
| 1084 | + |
| 1085 | + # Output should preserve agent dimension position |
| 1086 | + expected_shape = (batch_size, n_agents, seq_len, output_dim) |
| 1087 | + assert ( |
| 1088 | + out.shape == expected_shape |
| 1089 | + ), f"Expected {expected_shape}, got {out.shape}" |
| 1090 | + |
| 1091 | + # Verify different agents produce different outputs (unless share_params with same input) |
| 1092 | + if not share_params: |
| 1093 | + for i in range(n_agents): |
| 1094 | + for j in range(i + 1, n_agents): |
| 1095 | + assert not torch.allclose(out[:, i], out[:, j]) |
| 1096 | + |
1013 | 1097 | @pytest.mark.parametrize("n_agents", [1, 3]) |
1014 | 1098 | @pytest.mark.parametrize("share_params", [True, False]) |
1015 | 1099 | @pytest.mark.parametrize("centralized", [True, False]) |
|
0 commit comments