I have a question about the current MetaLayer implementation, especially the node model.

In short: does the NodeModel of the MetaLayer take edge direction into account?

In detail: I have trouble understanding the code of the NodeModel (see below, taken from the documentation of 2.4.0). To my understanding, it aggregates over ALL edges connected to a node regardless of their direction, whereas I want it to take into account only the edges pointing towards the node.

    from torch.nn import Sequential as Seq, Linear as Lin, ReLU
    from torch_geometric.utils import scatter

    class NodeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))
            self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))

        def forward(self, x, edge_index, edge_attr, u, batch):
            # x: [N, F_x], where N is the number of nodes.
            # edge_index: [2, E] with max entry N - 1.
            # edge_attr: [E, F_e]
            # u: [B, F_u]
            # batch: [N] with max entry B - 1.
            row, col = edge_index
            out = torch.cat([x[row], edge_attr], dim=1)
            out = self.node_mlp_1(out)
            out = scatter(out, col, dim=0, dim_size=x.size(0),
                          reduce='mean')
            out = torch.cat([x, out, u[batch]], dim=1)
            return self.node_mlp_2(out)
Replies: 1 comment 1 reply
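All graphs in PyG are understood as directed. For bidirectional edges, one has to add both directions to the `edge_index`. In the NodeModel, messages are built from the source nodes `x[row]` and aggregated by the target index `col`, so each node only receives the edges that point towards it. As such, PyG already does what you want it to do.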
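To make the direction handling concrete, here is a minimal, dependency-free sketch in plain Python that mimics what `scatter(out, col, dim=0, dim_size=x.size(0), reduce='mean')` does with the target index. The helper `aggregate_incoming` and the toy features are hypothetical and only for illustration; they are not part of PyG, and the MLPs and edge attributes of the NodeModel are left out.

```python
def aggregate_incoming(x, edge_index):
    """Mean of source-node features over the edges pointing at each node.

    Hypothetical helper (not PyG API) mimicking
    scatter(x[row], col, dim=0, dim_size=len(x), reduce='mean').
    """
    row, col = edge_index  # row: source nodes, col: target nodes
    sums = [0.0] * len(x)
    counts = [0] * len(x)
    for src, dst in zip(row, col):
        sums[dst] += x[src]  # the message travels src -> dst only
        counts[dst] += 1
    # Nodes without incoming edges end up with 0.0 (count clamped to 1).
    return [s / max(c, 1) for s, c in zip(sums, counts)]


x = [1.0, 10.0, 100.0]  # one scalar feature per node

# Directed edges 0 -> 1 and 2 -> 1: only node 1 receives messages.
directed = [[0, 2], [1, 1]]
print(aggregate_incoming(x, directed))  # [0.0, 50.5, 0.0]

# Adding the reverse edges 1 -> 0 and 1 -> 2 makes the graph bidirectional.
bidirectional = [[0, 2, 1, 1], [1, 1, 0, 2]]
print(aggregate_incoming(x, bidirectional))  # [10.0, 50.5, 10.0]
```

In the directed case, nodes 0 and 2 have outgoing edges but still aggregate nothing, which is the behavior asked for in the question. They only start receiving messages once the reverse edges are added to `edge_index`.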