3 changes: 1 addition & 2 deletions benchmarl/conf/model/layers/gnn.yaml
@@ -4,8 +4,7 @@ topology: full
 self_loops: False
 
 gnn_class: torch_geometric.nn.conv.GraphConv
-gnn_kwargs:
-  aggr: "add"
+gnn_kwargs: {}
 
 position_key: null
 pos_features: 0
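With the default gnn_kwargs now an empty dict, GraphConv simply falls back to its own constructor defaults (its aggregation already defaults to "add", so the shipped config keeps its previous behaviour). A minimal sketch of supplying kwargs from Python instead of YAML, using the config class this PR touches (import path taken from the diff; flag values are illustrative):

from torch_geometric.nn.conv import GraphConv

from benchmarl.models.gnn import GnnConfig

# Explicit kwargs are still honoured; "add" mirrors the old YAML default.
config = GnnConfig(
    topology="full",
    self_loops=False,
    gnn_class=GraphConv,
    gnn_kwargs={"aggr": "add"},
    exclude_pos_from_node_features=False,
)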
56 changes: 38 additions & 18 deletions benchmarl/models/gnn.py
@@ -57,7 +57,7 @@ class Gnn(Model):
         topology (str): Topology of the graph adjacency matrix. Options: "full", "empty", "from_pos". "from_pos" builds
             the topology dynamically based on ``position_key`` and ``edge_radius``.
         self_loops (bool): Whether the resulting adjacency matrix will have self loops.
-        gnn_class (Type[torch_geometric.nn.MessagePassing]): the gnn convolution class to use
+        gnn_class (Type[torch.nn.Module]): the gnn convolution class to use
         gnn_kwargs (dict, optional): the dict of arguments to pass to the gnn conv class
         position_key (str, optional): if provided, it will need to match a leaf key in the tensordict coming from the env
             (in the `observation_spec`) representing the agent position.
@@ -124,7 +124,7 @@ def __init__(
         self,
         topology: str,
         self_loops: bool,
-        gnn_class: Type[torch_geometric.nn.MessagePassing],
+        gnn_class: Type[torch.nn.Module],
         gnn_kwargs: Optional[dict],
         position_key: Optional[str],
         exclude_pos_from_node_features: Optional[bool],
@@ -142,6 +142,9 @@ def __init__(
         self.edge_radius = edge_radius
         self.pos_features = pos_features
         self.vel_features = vel_features
+        self.is_message_passing = issubclass(
+            gnn_class, torch_geometric.nn.MessagePassing
+        )
 
         super().__init__(**kwargs)
 
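The issubclass check is what later routes each layer down the matching branch. A quick illustration with the two classes exercised in this PR (assuming torch_geometric is installed):

import torch_geometric
from torch_geometric.nn.attention import PerformerAttention
from torch_geometric.nn.conv import GraphConv

# Convolutions inherit from MessagePassing...
print(issubclass(GraphConv, torch_geometric.nn.MessagePassing))  # True
# ...while attention modules are plain torch.nn.Module subclasses.
print(issubclass(PerformerAttention, torch_geometric.nn.MessagePassing))  # False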
@@ -164,9 +167,24 @@ def __init__(
 
         if gnn_kwargs is None:
             gnn_kwargs = {}
-        gnn_kwargs.update(
-            {"in_channels": self.input_features, "out_channels": self.output_features}
-        )
+        if self.is_message_passing:
+            gnn_kwargs.update(
+                {
+                    "in_channels": self.input_features,
+                    "out_channels": self.output_features,
+                }
+            )
+        else:
+            if self.input_features != self.output_features:
+                raise ValueError(
+                    "Input and output features must be the same for non-MessagePassing GNN classes"
+                )
+            gnn_kwargs.update(
+                {
+                    "channels": self.input_features,
+                    "heads": gnn_kwargs.get("heads", 1),
+                }
+            )
         self.gnn_supports_edge_attrs = (
             "edge_dim" in inspect.getfullargspec(gnn_class).args
         )
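In effect, the two branches build constructor calls of the following shapes (a sketch with illustrative channel sizes, not values from this PR):

from torch_geometric.nn.attention import PerformerAttention
from torch_geometric.nn.conv import GraphConv

# MessagePassing branch: in/out channel sizes come from the model's specs.
conv = GraphConv(in_channels=8, out_channels=16)

# Non-MessagePassing branch: one shared channels size (which is why input
# and output features must match) plus a heads count defaulting to 1.
attn = PerformerAttention(channels=8, heads=1)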
@@ -315,18 +333,20 @@ def _forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         input = torch.cat(input, dim=-1)
         batch_size = input.shape[:-2]
 
-        graph = _batch_from_dense_to_ptg(
-            x=input,
-            edge_index=self.edge_index,
-            pos=pos,
-            vel=vel,
-            self_loops=self.self_loops,
-            edge_radius=self.edge_radius,
-        )
-        forward_gnn_params = {
-            "x": graph.x,
-            "edge_index": graph.edge_index,
-        }
+        if self.is_message_passing:
+            graph = _batch_from_dense_to_ptg(
+                x=input,
+                edge_index=self.edge_index,
+                pos=pos,
+                vel=vel,
+                self_loops=self.self_loops,
+                edge_radius=self.edge_radius,
+            )
+            forward_gnn_params = {"x": graph.x, "edge_index": graph.edge_index}
+        else:
+            forward_gnn_params = {
+                "x": input,
+            }
         if (
             self.position_key is not None or self.velocity_key is not None
         ) and self.gnn_supports_edge_attrs:
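A standalone sketch of what each branch ends up feeding to the layer (tensor sizes and the toy graph are illustrative, not taken from BenchMARL):

import torch
from torch_geometric.nn.attention import PerformerAttention
from torch_geometric.nn.conv import GraphConv

x = torch.rand(4, 3, 8)  # [*batch_size, n_agents, features]

# MessagePassing branch: agents are flattened into one disjoint batched
# graph and the layer receives x plus an edge_index.
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # toy 3-node graph
out_mp = GraphConv(in_channels=8, out_channels=8)(x.flatten(0, 1), edge_index)  # [12, 8]

# Non-MessagePassing branch: the dense tensor is passed through as-is.
out_attn = PerformerAttention(channels=8, heads=1)(x)  # [4, 3, 8]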
Expand Down Expand Up @@ -468,7 +488,7 @@ class GnnConfig(ModelConfig):
topology: str = MISSING
self_loops: bool = MISSING

gnn_class: Type[torch_geometric.nn.MessagePassing] = MISSING
gnn_class: Type[torch.nn.Module] = MISSING
gnn_kwargs: Optional[dict] = None

position_key: Optional[str] = None
97 changes: 97 additions & 0 deletions test/test_models.py
@@ -386,6 +386,103 @@ def test_gnn_edge_attrs(
         output = gnn(obs_input)
         assert output_spec.expand(batch_size).is_in(output)
 
+    @pytest.mark.parametrize("share_params", [True, False])
+    def test_gnn_attention(
+        self,
+        share_params,
+        n_agents=3,
+        agent_group="agents",
+        features=5,
+    ):
+        torch.manual_seed(0)
+
+        input_spec = Composite(
+            {
+                agent_group: Composite(
+                    {"in": Unbounded(shape=(n_agents, features))},
+                    shape=(n_agents,),
+                )
+            }
+        )
+
+        output_spec = Composite(
+            {
+                agent_group: Composite(
+                    {"out": Unbounded(shape=(n_agents, features))},
+                    shape=(n_agents,),
+                )
+            },
+        )
+
+        gnn = GnnConfig(
+            topology="full",
+            self_loops=True,
+            gnn_class=torch_geometric.nn.attention.PerformerAttention,
+            gnn_kwargs=None,
+            exclude_pos_from_node_features=False,
+        ).get_model(
+            input_spec=input_spec,
+            output_spec=output_spec,
+            agent_group=agent_group,
+            input_has_agent_dim=True,
+            n_agents=n_agents,
+            centralised=False,
+            share_params=share_params,
+            device="cpu",
+            action_spec=None,
+        )
+
+        obs_input = input_spec.expand(4).rand()
+        output = gnn(obs_input)
+        assert output_spec.expand(4).is_in(output)
+
+    @pytest.mark.parametrize("share_params", [True, False])
+    def test_gnn_attention_raises(
+        self,
+        share_params,
+        n_agents=3,
+        agent_group="agents",
+        features=5,
+    ):
+        torch.manual_seed(0)
+
+        input_spec = Composite(
+            {
+                agent_group: Composite(
+                    {"in": Unbounded(shape=(n_agents, features))},
+                    shape=(n_agents,),
+                )
+            }
+        )
+
+        output_spec = Composite(
+            {
+                agent_group: Composite(
+                    {"out": Unbounded(shape=(n_agents, features + 1))},
+                    shape=(n_agents,),
+                )
+            },
+        )
+
+        with pytest.raises(ValueError, match="Input and output features must"):
+            GnnConfig(
+                topology="full",
+                self_loops=True,
+                gnn_class=torch_geometric.nn.attention.PerformerAttention,
+                gnn_kwargs=None,
+                exclude_pos_from_node_features=False,
+            ).get_model(
+                input_spec=input_spec,
+                output_spec=output_spec,
+                agent_group=agent_group,
+                input_has_agent_dim=True,
+                n_agents=n_agents,
+                centralised=False,
+                share_params=share_params,
+                device="cpu",
+                action_spec=None,
+            )
+
 
 class TestDeepsets:
     @pytest.mark.parametrize("share_params", [True, False])