  1. The MetaLayer is designed to follow Algorithm 1 of the referenced paper. The global graph features u are passed to both the edge-level and the node-level modules. For the global module, however, you need to aggregate the node and edge features yourself; that is, steps 4-6 have to be implemented in global_model, for example:
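import torch
from torch_scatter import scatter_mean

class GlobalModel(torch.nn.Module):
    ...  # constructor omitted; it must define self.global_mlp
    def forward(self, x, edge_index, edge_attr, u, batch):
        # Average node features per graph (batch maps each node to its graph).
        node_aggregate = scatter_mean(x, batch, dim=0, dim_size=u.size(0))
        # Assign each edge to the graph of its target node, then average per graph.
        edge_aggregate = scatter_mean(edge_attr, batch[edge_index[1]], dim=0, dim_size=u.size(0))
        # Concatenate global, aggregated node and aggregated edge features.
        out = torch.cat([u, node_aggregate, edge_aggregate], dim=1)
        return self.global_mlp(out)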
  2. This looks correct to me, although you can also implement the two
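For readers without torch_scatter installed, the global update from point 1 (steps 4-6) can be sketched in plain PyTorch. This is a minimal illustration, not the MetaLayer implementation: it assumes per-graph mean aggregation as in the example above, and the helper scatter_mean_torch, the constructor arguments and the MLP sizes are hypothetical choices made only for this sketch. As in the original snippet, each edge is assigned to the graph of its target node via batch[edge_index[1]].

```python
import torch


def scatter_mean_torch(src: torch.Tensor, index: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Hypothetical helper: mean of the rows of src grouped by index.

    Groups that receive no rows (e.g. a graph without edges) yield zeros.
    """
    sums = torch.zeros(num_groups, src.size(1), dtype=src.dtype)
    sums.index_add_(0, index, src)  # per-group sum of rows
    counts = torch.zeros(num_groups, dtype=src.dtype)
    counts.index_add_(0, index, torch.ones(index.size(0), dtype=src.dtype))
    return sums / counts.clamp(min=1).unsqueeze(1)  # avoid division by zero


class GlobalModel(torch.nn.Module):
    def __init__(self, u_dim: int, x_dim: int, e_dim: int, hidden: int = 32):
        super().__init__()
        # Hypothetical MLP; any network taking the concatenated vector would do.
        self.global_mlp = torch.nn.Sequential(
            torch.nn.Linear(u_dim + x_dim + e_dim, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, u_dim),
        )

    def forward(self, x, edge_index, edge_attr, u, batch):
        num_graphs = u.size(0)
        node_aggregate = scatter_mean_torch(x, batch, num_graphs)
        edge_aggregate = scatter_mean_torch(edge_attr, batch[edge_index[1]], num_graphs)
        out = torch.cat([u, node_aggregate, edge_aggregate], dim=1)
        return self.global_mlp(out)


# Toy batch: graph 0 has nodes 0-1, graph 1 has nodes 2-4.
x = torch.arange(15.0).view(5, 3)
batch = torch.tensor([0, 0, 1, 1, 1])
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 4]])
edge_attr = torch.tensor([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0], [0.0, 4.0]])
u = torch.zeros(2, 4)

node_agg = scatter_mean_torch(x, batch, u.size(0))
edge_agg = scatter_mean_torch(edge_attr, batch[edge_index[1]], u.size(0))
model = GlobalModel(u_dim=4, x_dim=3, e_dim=2)
new_u = model(x, edge_index, edge_attr, u, batch)
print(node_agg, edge_agg, new_u.shape)
```

Passing the number of graphs explicitly (like dim_size in torch_scatter's scatter_mean) keeps one output row per graph even when a graph has no edges, so the torch.cat over u and the aggregates does not fail on a shape mismatch.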

Answer selected by varghele
Category
Q&A