Incorporating Edge Features into GATConv for edge-level prediction #9055
-
I am encountering a problem: I need to predict labels for the 4000 input edges, but the model only outputs 1000 predictions.
Answered by EdisonLeeeee on Mar 18, 2024
-
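You need an additional module (a decoder) to turn the node representations into edge representations. A GNN layer such as `GATConv` returns one output row per node, not one per edge, so you have to combine the embeddings of each edge's two endpoint nodes to get one prediction per edge. For example:

```python
import torch


class EdgeDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Input size 32 = two concatenated 16-dim node embeddings;
        # set it to 2 * (your encoder's output size).
        self.lin1 = torch.nn.Linear(32, 64)
        self.lin2 = torch.nn.Linear(64, 1)

    def forward(self, x, edge_index):
        x1 = x[edge_index[0]]  # source-node embeddings, [num_edges, 16]
        x2 = x[edge_index[1]]  # target-node embeddings, [num_edges, 16]
        h = torch.cat([x1, x2], dim=1)  # [num_edges, 32]
        h = self.lin1(h).relu()
        return self.lin2(h).sigmoid()  # one probability per edge, [num_edges, 1]


# ...
decoder = EdgeDecoder()
# ...
node_emb = model(train_data)  # one embedding per node, [num_nodes, 16]
predictions = decoder(node_emb, train_data.edge_index)  # [num_edges, 1]
# ...
```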
Answer selected by koongjl