From 59a2bbdfa9f962bcd9b6404142f9bf2c2b8f900c Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Sat, 4 Oct 2025 18:36:33 +0200 Subject: [PATCH 1/3] no default --- benchmarl/conf/model/layers/gnn.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/benchmarl/conf/model/layers/gnn.yaml b/benchmarl/conf/model/layers/gnn.yaml index 4e9776ce..19f6ccfa 100644 --- a/benchmarl/conf/model/layers/gnn.yaml +++ b/benchmarl/conf/model/layers/gnn.yaml @@ -4,8 +4,7 @@ topology: full self_loops: False gnn_class: torch_geometric.nn.conv.GraphConv -gnn_kwargs: - aggr: "add" +gnn_kwargs: {} position_key: null pos_features: 0 From b0efe9aaea0659e805956988046e4d70983f73d1 Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Sat, 4 Oct 2025 18:40:54 +0200 Subject: [PATCH 2/3] better support --- benchmarl/models/gnn.py | 56 ++++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/benchmarl/models/gnn.py b/benchmarl/models/gnn.py index 855ba81b..a5e039e2 100644 --- a/benchmarl/models/gnn.py +++ b/benchmarl/models/gnn.py @@ -57,7 +57,7 @@ class Gnn(Model): topology (str): Topology of the graph adjacency matrix. Options: "full", "empty", "from_pos". "from_pos" builds the topology dynamically based on ``position_key`` and ``edge_radius``. self_loops (str): Whether the resulting adjacency matrix will have self loops. - gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use + gnn_class (Type[torch.nn.Module]): the gnn convolution class to use gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env (in the `observation_spec`) representing the agent position. @@ -124,7 +124,7 @@ def __init__( self, topology: str, self_loops: bool, - gnn_class: Type[torch_geometric.nn.MessagePassing], + gnn_class: Type[torch.nn.Module], gnn_kwargs: Optional[dict], position_key: Optional[str], exclude_pos_from_node_features: Optional[bool], @@ -142,6 +142,9 @@ def __init__( self.edge_radius = edge_radius self.pos_features = pos_features self.vel_features = vel_features + self.is_message_passing = issubclass( + gnn_class, torch_geometric.nn.MessagePassing + ) super().__init__(**kwargs) @@ -164,9 +167,24 @@ def __init__( if gnn_kwargs is None: gnn_kwargs = {} - gnn_kwargs.update( - {"in_channels": self.input_features, "out_channels": self.output_features} - ) + if self.is_message_passing: + gnn_kwargs.update( + { + "in_channels": self.input_features, + "out_channels": self.output_features, + } + ) + else: + if self.input_features != self.output_features: + raise ValueError( + "Input and output features must be the same for non-MessagePassing GNN classes" + ) + gnn_kwargs.update( + { + "channels": self.input_features, + "heads": gnn_kwargs.get("heads", 1), + } + ) self.gnn_supports_edge_attrs = ( "edge_dim" in inspect.getfullargspec(gnn_class).args ) @@ -315,18 +333,20 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: input = torch.cat(input, dim=-1) batch_size = input.shape[:-2] - graph = _batch_from_dense_to_ptg( - x=input, - edge_index=self.edge_index, - pos=pos, - vel=vel, - self_loops=self.self_loops, - edge_radius=self.edge_radius, - ) - forward_gnn_params = { - "x": graph.x, - "edge_index": graph.edge_index, - } + if self.is_message_passing: + graph = _batch_from_dense_to_ptg( + x=input, + edge_index=self.edge_index, + pos=pos, + vel=vel, + self_loops=self.self_loops, + edge_radius=self.edge_radius, + ) + forward_gnn_params = {"x": graph.x, "edge_index": graph.edge_index} + else: + forward_gnn_params = { + "x": input, + } if ( self.position_key is not None or self.velocity_key is not None ) and self.gnn_supports_edge_attrs: @@ -468,7 +488,7 @@ class GnnConfig(ModelConfig): topology: str = MISSING self_loops: bool = MISSING - gnn_class: Type[torch_geometric.nn.MessagePassing] = MISSING + gnn_class: Type[torch.nn.Module] = MISSING gnn_kwargs: Optional[dict] = None position_key: Optional[str] = None From a2310c537e861ab78d1839a07edac37d7c7fcd2b Mon Sep 17 00:00:00 2001 From: Yoann Poupart <66315201+Xmaster6y@users.noreply.github.com> Date: Sat, 4 Oct 2025 19:04:25 +0200 Subject: [PATCH 3/3] some tests --- test/test_models.py | 97 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/test/test_models.py b/test/test_models.py index 0cf0c798..266574ae 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -386,6 +386,103 @@ def test_gnn_edge_attrs( output = gnn(obs_input) assert output_spec.expand(batch_size).is_in(output) + @pytest.mark.parametrize("share_params", [True, False]) + def test_gnn_attention( + self, + share_params, + n_agents=3, + agent_goup="agents", + features=5, + ): + torch.manual_seed(0) + + input_spec = Composite( + { + agent_goup: Composite( + {"in": Unbounded(shape=(n_agents, features))}, + shape=(n_agents,), + ) + } + ) + + output_spec = Composite( + { + agent_goup: Composite( + {"out": Unbounded(shape=(n_agents, features))}, + shape=(n_agents,), + ) + }, + ) + + gnn = GnnConfig( + topology="full", + self_loops=True, + gnn_class=torch_geometric.nn.attention.PerformerAttention, + gnn_kwargs=None, + exclude_pos_from_node_features=False, + ).get_model( + input_spec=input_spec, + output_spec=output_spec, + agent_group=agent_goup, + input_has_agent_dim=True, + n_agents=n_agents, + centralised=False, + share_params=share_params, + device="cpu", + action_spec=None, + ) + + obs_input = input_spec.expand(4).rand() + output = gnn(obs_input) + assert output_spec.expand(4).is_in(output) + + @pytest.mark.parametrize("share_params", [True, False]) + def test_gnn_attention_raises( + self, + share_params, + n_agents=3, + agent_goup="agents", + features=5, + ): + torch.manual_seed(0) + + input_spec = Composite( + { + agent_goup: Composite( + {"in": Unbounded(shape=(n_agents, features))}, + shape=(n_agents,), + ) + } + ) + + output_spec = Composite( + { + agent_goup: Composite( + {"out": Unbounded(shape=(n_agents, features + 1))}, + shape=(n_agents,), + ) + }, + ) + + with pytest.raises(ValueError, match="Input and output features must"): + GnnConfig( + topology="full", + self_loops=True, + gnn_class=torch_geometric.nn.attention.PerformerAttention, + gnn_kwargs=None, + exclude_pos_from_node_features=False, + ).get_model( + input_spec=input_spec, + output_spec=output_spec, + agent_group=agent_goup, + input_has_agent_dim=True, + n_agents=n_agents, + centralised=False, + share_params=share_params, + device="cpu", + action_spec=None, + ) + class TestDeepsets: @pytest.mark.parametrize("share_params", [True, False])