In general, data.num_nodes is not well-defined for bipartite graphs, because they have two separate node sets, so you should avoid calling it in this scenario. Instead, use x_s.size(0) and x_t.size(0) to infer the number of source and destination nodes, respectively. Alternatively, you can override num_nodes directly:

```python
import torch
from torch_geometric.data import Data


class BipartiteData(Data):
    def __init__(self, edge_index=None, x_s=None, x_t=None, edge_attr=None):
        super().__init__()
        self.edge_index = edge_index
        self.x_s = x_s  # source node features
        self.x_t = x_t  # target (destination) node features
        self.edge_attr = edge_attr

    def __inc__(self, key, value, *args, **kwargs):
        # When batching, offset source and target indices by their own node counts.
        if key == 'edge_index':
            return torch.tensor([[self.x_s.size(0)], [s…
```
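For intuition, the following is a minimal pure-Python sketch of what a per-row increment in __inc__ achieves when bipartite graphs are batched. It assumes that the truncated second entry above is the target-node count. It does not use PyG. The (edge_index, x_s, x_t) tuple layout and the collate_bipartite function are hypothetical stand-ins, with plain lists in place of tensors and len(x_s) playing the role of x_s.size(0).

```python
from typing import List, Tuple

# Hypothetical stand-in for one bipartite graph:
# (edge_index as [source_row, target_row], x_s, x_t)
Graph = Tuple[List[List[int]], list, list]


def collate_bipartite(graphs: List[Graph]) -> Tuple[List[List[int]], int, int]:
    """Concatenate bipartite graphs, shifting source and target indices separately.

    Row 0 (source indices) is offset by the number of source nodes seen so far;
    row 1 (target indices) is offset by the number of target nodes seen so far.
    """
    src_row: List[int] = []
    dst_row: List[int] = []
    src_offset = dst_offset = 0
    for edge_index, x_s, x_t in graphs:
        src_row.extend(i + src_offset for i in edge_index[0])
        dst_row.extend(j + dst_offset for j in edge_index[1])
        # Node counts come from the feature lists (len(x_s) ~ x_s.size(0)),
        # not from a single shared num_nodes value.
        src_offset += len(x_s)
        dst_offset += len(x_t)
    return [src_row, dst_row], src_offset, dst_offset


# Example: 2 source / 3 target nodes, then 1 source / 2 target nodes.
g1 = ([[0, 1], [0, 2]], [[1.0], [2.0]], [[0.0], [0.0], [0.0]])
g2 = ([[0], [1]], [[3.0]], [[0.0], [0.0]])
edge_index, n_s, n_t = collate_bipartite([g1, g2])
print(edge_index, n_s, n_t)  # [[0, 1, 2], [0, 2, 4]] 3 5
```

A single shared offset (such as one num_nodes for both node sets) would shift the target indices by the wrong amount whenever x_s and x_t differ in size, which is why the two rows are incremented independently.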

Replies: 1 comment, 4 replies (@zcaicaros, @rusty1s, @zcaicaros, @rusty1s)

Answer selected by zcaicaros
Category: Q&A
Labels: None yet
2 participants