Hi PyG creator and community, I have a question about bipartite graph data. Any hint from you would be much appreciated! Following this doc, I create a bipartite graph data object like this:
Then I instantiate it like:
But when I access the
Q1: Is it expected behavior that only the target nodes are counted? To remove the warning, I add
Now the warning is gone and the
So far so good, but then a new problem occurs when I try to create a batch using
Then I just remove the line
As expected, it only counts the target nodes, which is
Q2: So how should I create bipartite graph data and batch it?
In general, `data.num_nodes` is undefined for bipartite graphs, and you should simply avoid calling it in this scenario. Instead, use `x_s.size(0)` and `x_t.size(0)` to infer the number of source and destination nodes. Alternatively, you can override `num_nodes` directly:

```python
import torch
from torch_geometric.data import Data


class BipartiteData(Data):
    def __init__(self, edge_index=None, x_s=None, x_t=None, edge_attr=None):
        super().__init__()
        self.edge_index = edge_index  # row 0 indexes x_s, row 1 indexes x_t
        self.x_s = x_s  # source node features
        self.x_t = x_t  # target node features
        self.edge_attr = edge_attr

    def __inc__(self, key, value, *args, **kwargs):
        # When batching, shift source indices by the number of source nodes
        # and target indices by the number of target nodes.
        if key == 'edge_index':
            return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
        return super().__inc__(key, value, *args, **kwargs)

    @property
    def num_nodes(self):
        # Total count over both node sets (both use dim 0 as the node dimension).
        return self.x_s.size(0) + self.x_t.size(0)
```