
I think the issue is that your heterogeneous model is still expected to return a single tensor, not a dictionary of tensors keyed by node type. So what you can do is select the output for the node type you care about (here 'paper') inside forward, something like:

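import torch
from torch_geometric.nn import GraphSAGE, to_hetero

class HeteroSAGE(torch.nn.Module):
    def __init__(self, metadata):
        super().__init__()
        # Lazy input size (-1); to_hetero converts the homogeneous GraphSAGE
        # into one with separate weights per node and edge type.
        self.graph_sage = to_hetero(GraphSAGE(-1, 32, num_layers=2), metadata)
        self.lin = torch.nn.Linear(32, 1)

    def forward(self, x_dict, edge_index_dict) -> torch.Tensor:
        # graph_sage returns a dict {node_type: tensor}; keep only 'paper'
        # so the model as a whole returns a single tensor.
        return self.lin(self.graph_sage(x_dict, edge_index_dict)['paper'])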
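To illustrate the same pattern without needing PyG installed, here is a minimal sketch in plain PyTorch. `ToyHeteroEncoder` is a hypothetical stand-in for a `to_hetero(...)` model (just one `Linear` per node type, no message passing), and `SelectNodeType` is a hypothetical wrapper that picks one node type's tensor out of the returned dict before applying the output layer:

```python
import torch
from torch import nn


class ToyHeteroEncoder(nn.Module):
    """Hypothetical stand-in for a to_hetero(...) model: returns one tensor per node type."""

    def __init__(self, in_dims, hidden=32):
        super().__init__()
        # One Linear per node type; no message passing, purely illustrative.
        self.lins = nn.ModuleDict({nt: nn.Linear(d, hidden) for nt, d in in_dims.items()})

    def forward(self, x_dict):
        # Returns a dict {node_type: tensor}, like a to_hetero model does.
        return {nt: self.lins[nt](x).relu() for nt, x in x_dict.items()}


class SelectNodeType(nn.Module):
    """Hypothetical wrapper: reduces a dict output to one node type, then applies a head."""

    def __init__(self, encoder, node_type, hidden=32, out=1):
        super().__init__()
        self.encoder = encoder
        self.node_type = node_type
        self.lin = nn.Linear(hidden, out)

    def forward(self, x_dict):
        out_dict = self.encoder(x_dict)             # dict: node type -> tensor
        return self.lin(out_dict[self.node_type])   # single tensor


torch.manual_seed(0)
x_dict = {"paper": torch.randn(4, 8), "author": torch.randn(3, 5)}
model = SelectNodeType(ToyHeteroEncoder({"paper": 8, "author": 5}), "paper")
out = model(x_dict)
print(type(out).__name__, tuple(out.shape))  # Tensor (4, 1)
```

Whatever consumes the model then sees a plain tensor of shape [num_paper_nodes, 1] instead of a dict; in the real setup the encoder would be the to_hetero GraphSAGE from the HeteroSAGE class.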

Answer selected by bl3e967
Category: Q&A