Hey. I have a couple of questions regarding the `MetaLayer` of PyTorch Geometric (https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/meta.html).
I tried to stay close to the code of the `NodeModel`, but I am not entirely sure I got it right, so any advice would be much appreciated.
Thanks in advance!
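```python
import torch
from torch_scatter import scatter_mean

class GlobalModel(torch.nn.Module):
    ...  # __init__ elided; it is expected to define self.global_mlp

    def forward(self, x, edge_index, edge_attr, u, batch):
        # Per-graph mean of node features (batch[i] is the graph of node i).
        node_aggregate = scatter_mean(x, batch, dim=0, dim_size=u.size(0))
        # Per-graph mean of edge features; each edge belongs to the graph of its target node.
        edge_aggregate = scatter_mean(edge_attr, batch[edge_index[1]], dim=0, dim_size=u.size(0))
        out = torch.cat([u, node_aggregate, edge_aggregate], dim=1)
        return self.global_mlp(out)
```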
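To make the aggregation in this `GlobalModel` concrete, here is a small sketch on a hand-built batch of two graphs. It uses only plain PyTorch: `scatter_mean` below is a hypothetical stand-in for `torch_scatter.scatter_mean` built from `index_add_` and `bincount`, and the toy tensors are made up for illustration. It shows that `batch[edge_index[1]]` assigns every edge to the graph of its target node, and why the number of graphs should be passed explicitly: graph 1 has no edges, so an output sized from the largest index seen would have one row instead of two.

```python
import torch


def scatter_mean(src, index, dim_size):
    # Stand-in for torch_scatter.scatter_mean along dim 0; empty groups stay zero.
    out = torch.zeros(dim_size, src.size(1), dtype=src.dtype)
    out.index_add_(0, index, src)
    count = torch.bincount(index, minlength=dim_size).clamp(min=1).to(src.dtype)
    return out / count.unsqueeze(1)


# Two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
x = torch.tensor([[1.0], [2.0], [3.0], [10.0], [20.0]])
batch = torch.tensor([0, 0, 0, 1, 1])
# Edges 0->1 and 1->2, both inside graph 0; graph 1 has no edges at all.
edge_index = torch.tensor([[0, 1], [1, 2]])
edge_attr = torch.tensor([[4.0], [6.0]])
num_graphs = 2  # plays the role of u.size(0)

node_aggregate = scatter_mean(x, batch, num_graphs)
edge_batch = batch[edge_index[1]]  # graph id of each edge's target node
edge_aggregate = scatter_mean(edge_attr, edge_batch, num_graphs)
print(edge_batch.tolist())      # [0, 0]
print(node_aggregate.tolist())  # [[2.0], [15.0]]
print(edge_aggregate.tolist())  # [[5.0], [0.0]]
```

With `torch_scatter.scatter_mean` itself, the equivalent is passing `dim_size=u.size(0)`; without it the output size is inferred as the largest index plus one, so concatenating with `u` would fail for a batch like this one.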
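`MetaLayer` tries to follow Algorithm 1 in the paper referenced in its documentation ("Relational Inductive Biases, Deep Learning, and Graph Networks"). The global graph features `u` are passed to both the edge-level and the node-level modules. For the global module, however, you need to aggregate the node and edge features yourself, so steps 4-6 (aggregate edges per graph, aggregate nodes per graph, update `u`) need to be implemented in `global_model`, as in the `GlobalModel` snippet above.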
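A minimal, plain-PyTorch sketch of that data flow, assuming the sub-model signatures from the PyG documentation example (`edge_model(src, dst, edge_attr, u, batch)`, `node_model(x, edge_index, edge_attr, u, batch)`, `global_model(x, edge_index, edge_attr, u, batch)`). `meta_layer_forward` is a hypothetical helper that mimics the update order of `MetaLayer` (edges, then nodes, then globals); single `nn.Linear` layers stand in for the MLPs, and `scatter_mean` is again a stand-in for `torch_scatter`. The edge and node models only index `u` with a graph id, while the global model has to aggregate nodes and edges itself.

```python
import torch
from torch import nn


def scatter_mean(src, index, dim_size):
    # Stand-in for torch_scatter.scatter_mean along dim 0; empty groups stay zero.
    out = torch.zeros(dim_size, src.size(1), dtype=src.dtype)
    out.index_add_(0, index, src)
    count = torch.bincount(index, minlength=dim_size).clamp(min=1).to(src.dtype)
    return out / count.unsqueeze(1)


class EdgeModel(nn.Module):
    def __init__(self, f_x, f_e, f_u):
        super().__init__()
        self.lin = nn.Linear(2 * f_x + f_e + f_u, f_e)

    def forward(self, src, dst, edge_attr, u, edge_batch):
        # u[edge_batch] copies each graph's global vector onto that graph's edges.
        return self.lin(torch.cat([src, dst, edge_attr, u[edge_batch]], dim=1))


class NodeModel(nn.Module):
    def __init__(self, f_x, f_e, f_u):
        super().__init__()
        self.lin = nn.Linear(f_x + f_e + f_u, f_x)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # Mean of the updated features of incoming edges, per target node.
        incoming = scatter_mean(edge_attr, edge_index[1], x.size(0))
        return self.lin(torch.cat([x, incoming, u[batch]], dim=1))


class GlobalModel(nn.Module):
    def __init__(self, f_x, f_e, f_u):
        super().__init__()
        self.lin = nn.Linear(f_u + f_x + f_e, f_u)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # Steps 4-6: aggregate edges and nodes per graph, then update u.
        num_graphs = u.size(0)
        node_agg = scatter_mean(x, batch, num_graphs)
        edge_agg = scatter_mean(edge_attr, batch[edge_index[1]], num_graphs)
        return self.lin(torch.cat([u, node_agg, edge_agg], dim=1))


def meta_layer_forward(edge_model, node_model, global_model,
                       x, edge_index, edge_attr, u, batch):
    # Hypothetical helper: edges are updated first, then nodes, then globals.
    row, col = edge_index
    edge_attr = edge_model(x[row], x[col], edge_attr, u, batch[row])
    x = node_model(x, edge_index, edge_attr, u, batch)
    u = global_model(x, edge_index, edge_attr, u, batch)
    return x, edge_attr, u


torch.manual_seed(0)
f_x, f_e, f_u = 3, 2, 4
x = torch.randn(5, f_x)
batch = torch.tensor([0, 0, 0, 1, 1])  # nodes 0-2: graph 0, nodes 3-4: graph 1
edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])
edge_attr = torch.randn(3, f_e)
u = torch.randn(2, f_u)

x_out, e_out, u_out = meta_layer_forward(
    EdgeModel(f_x, f_e, f_u), NodeModel(f_x, f_e, f_u), GlobalModel(f_x, f_e, f_u),
    x, edge_index, edge_attr, u, batch)
print(x_out.shape, e_out.shape, u_out.shape)
# torch.Size([5, 3]) torch.Size([3, 2]) torch.Size([2, 4])
```

In practice you would pass the three modules to `torch_geometric.nn.MetaLayer(edge_model, node_model, global_model)` rather than writing the forward pass by hand; the helper only spells out the order of the updates.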