You need an additional module (a decoder) that computes edge-level predictions from the node representations your model outputs. For example:

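import torch

class EdgeDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # lin1 expects 32 input features: two concatenated node embeddings
        # of 16 features each (adjust to your encoder's output size).
        self.lin1 = torch.nn.Linear(32, 64)
        self.lin2 = torch.nn.Linear(64, 1)

    def forward(self, x, edge_index):
        # Gather the embeddings of each edge's source and target nodes.
        x1 = x[edge_index[0]]
        x2 = x[edge_index[1]]

        h = torch.cat([x1, x2], dim=1)  # [num_edges, 32]
        h = self.lin1(h).relu()
        return self.lin2(h).sigmoid()  # one probability per edge: [num_edges, 1]

# ...
decoder = EdgeDecoder()
# ...
node_embeddings = model(train_data)  # node-level output of your encoder
predictions = decoder(node_embeddings, train_data.edge_index)
# ...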
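Below is a self-contained sketch of how such a decoder might be trained. It makes several assumptions that are not part of the answer above: the embedding size is a constructor parameter (`emb_dim`) instead of the hard-coded 32 = 2 × 16, random tensors stand in for the encoder's node embeddings, the positive and negative edges are hypothetical toy data, and `torch.nn.BCELoss` is used because the decoder already applies a sigmoid.

```python
import torch


class EdgeDecoder(torch.nn.Module):
    """MLP that scores an edge (src, dst) from its two endpoint embeddings."""

    def __init__(self, emb_dim: int, hidden_dim: int = 64):
        super().__init__()
        # Input is the concatenation [x_src || x_dst], hence 2 * emb_dim.
        self.lin1 = torch.nn.Linear(2 * emb_dim, hidden_dim)
        self.lin2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index  # each of shape [num_edges]
        h = torch.cat([x[src], x[dst]], dim=-1)  # [num_edges, 2 * emb_dim]
        h = self.lin1(h).relu()
        return self.lin2(h).sigmoid()  # [num_edges, 1], values in [0, 1]


torch.manual_seed(0)
num_nodes, emb_dim = 10, 16

# Assumption: random stand-in for the node embeddings an encoder would produce.
node_emb = torch.randn(num_nodes, emb_dim)

# Hypothetical toy data: 4 observed edges (label 1) and 4 non-edges (label 0).
pos_edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
neg_edge_index = torch.tensor([[0, 5, 7, 9], [6, 8, 2, 3]])
edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1)
labels = torch.cat([torch.ones(4, 1), torch.zeros(4, 1)])

decoder = EdgeDecoder(emb_dim)
optimizer = torch.optim.Adam(decoder.parameters(), lr=0.01)
loss_fn = torch.nn.BCELoss()  # expects probabilities, which the decoder returns

losses = []
for _ in range(50):
    optimizer.zero_grad()
    probs = decoder(node_emb, edge_index)
    loss = loss_fn(probs, labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(probs.shape)  # torch.Size([8, 1])
```

In a real setup the encoder and decoder are usually optimized together, so the encoder's parameters would also go into the optimizer and `node_emb` would be recomputed every step. If you return raw logits from `lin2` instead of applying `.sigmoid()`, `torch.nn.BCEWithLogitsLoss` is the numerically more stable choice.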

Answer selected by koongjl
Category
Q&A