Using edges of multiple types between the same pair of entity types in one metapath in `AddMetaPaths` #5047
-
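Hi! I'm currently working with [text lost in extraction]. For example, I have entities [text lost in extraction]. Applying `AddMetaPaths` to this data fails with the following error:

AssertionError Traceback (most recent call last)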
Input In [503], in <cell line: 1>()
----> 1 data = T.AddMetaPaths(metapaths=metapaths, drop_orig_edges=True, drop_unconnected_nodes=False)(data)
File ~/.conda/envs/myenv/lib/python3.10/site-packages/torch_geometric/transforms/add_metapaths.py:118, in AddMetaPaths.__call__(self, data)
114 for i, edge_type in enumerate(metapath[1:]):
115 adj2 = SparseTensor.from_edge_index(
116 edge_index=data[edge_type].edge_index,
117 sparse_sizes=data[edge_type].size())
--> 118 adj1 = adj1 @ adj2
120 row, col, _ = adj1.coo()
121 new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
File ~/.conda/envs/myenv/lib/python3.10/site-packages/torch_sparse/matmul.py:149, in <lambda>(self, other)
145 SparseTensor.spspmm = lambda self, other, reduce="sum": spspmm(
146 self, other, reduce)
147 SparseTensor.matmul = lambda self, other, reduce="sum": matmul(
148 self, other, reduce)
--> 149 SparseTensor.__matmul__ = lambda self, other: matmul(self, other, 'sum')
File ~/.conda/envs/myenv/lib/python3.10/site-packages/torch_sparse/matmul.py:140, in matmul(src, other, reduce)
138 return spmm(src, other, reduce)
139 elif isinstance(other, SparseTensor):
--> 140 return spspmm(src, other, reduce)
141 raise ValueError
File ~/.conda/envs/myenv/lib/python3.10/site-packages/torch_sparse/matmul.py:117, in spspmm(src, other, reduce)
114 def spspmm(src: SparseTensor, other: SparseTensor,
115 reduce: str = "sum") -> SparseTensor:
116 if reduce == 'sum' or reduce == 'add':
--> 117 return spspmm_sum(src, other)
118 elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
119 raise NotImplementedError
File ~/.conda/envs/myenv/lib/python3.10/site-packages/torch_sparse/matmul.py:93, in spspmm_sum(src, other)
92 def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
---> 93 assert src.sparse_size(1) == other.sparse_size(0)
94 rowptrA, colA, valueA = src.csr()
95 rowptrB, colB, valueB = other.csr()
AssertionError:

Here are also some toy examples:

# NOT OK
temp = {('x', 'a', 'y'):{'edge_index':torch.tensor([[0], [0]])}, ('y', 'a_inv', 'x'):{'edge_index':torch.tensor([[0], [0]])},
('x', 'b', 'y'):{'edge_index':torch.tensor([[0, 0], [0, 1]])}, ('y', 'b_inv', 'x'):{'edge_index':torch.tensor([[0, 1], [0, 0]])}}
temp = HeteroData(temp)
temp['x'].x = torch.randn(1, 4)
metapaths = [[('x', 'a', 'y'), ('y', 'a_inv', 'x')],
[('x', 'a', 'y'), ('y', 'b_inv', 'x')]
]
T.AddMetaPaths(metapaths=metapaths, drop_orig_edges=True, drop_unconnected_nodes=False)(temp)
# NOT OK
temp = {('x', 'a', 'y'):{'edge_index':torch.tensor([[0], [0]])}, ('y', 'a_inv', 'x'):{'edge_index':torch.tensor([[0], [0]])},
('x', 'b', 'y'):{'edge_index':torch.tensor([[0], [1]])}, ('y', 'b_inv', 'x'):{'edge_index':torch.tensor([[1], [0]])}}
temp = HeteroData(temp)
temp['x'].x = torch.randn(1, 4)
metapaths = [[('x', 'a', 'y'), ('y', 'a_inv', 'x')],
[('x', 'a', 'y'), ('y', 'b_inv', 'x')]
]
T.AddMetaPaths(metapaths=metapaths, drop_orig_edges=True, drop_unconnected_nodes=False)(temp)
# NOT OK
temp = {('x', 'a', 'y'):{'edge_index':torch.tensor([[0], [0]])}, ('y', 'a_inv', 'x'):{'edge_index':torch.tensor([[0], [0]])},
('x', 'b', 'y'):{'edge_index':torch.tensor([[1], [0]])}, ('y', 'b_inv', 'x'):{'edge_index':torch.tensor([[0], [1]])}}
temp = HeteroData(temp)
temp['x'].x = torch.randn(2, 4)
metapaths = [[('x', 'a', 'y'), ('y', 'a_inv', 'x')],
[('x', 'a', 'y'), ('y', 'b_inv', 'x')]
]
T.AddMetaPaths(metapaths=metapaths, drop_orig_edges=True, drop_unconnected_nodes=False)(temp)
# OK
temp = {('x', 'a', 'y'):{'edge_index':torch.tensor([[0, 1], [0, 1]])}, ('y', 'a_inv', 'x'):{'edge_index':torch.tensor([[0, 1], [0, 1]])},
('x', 'b', 'y'):{'edge_index':torch.tensor([[0, 1], [1, 0]])}, ('y', 'b_inv', 'x'):{'edge_index':torch.tensor([[1, 0], [0, 1]])}}
temp = HeteroData(temp)
temp['x'].x = torch.randn(2, 4)
metapaths = [
[('x', 'a', 'y'), ('y', 'a_inv', 'x')],
[('x', 'a', 'y'), ('y', 'b_inv', 'x')]
]
T.AddMetaPaths(metapaths=metapaths, drop_orig_edges=True, drop_unconnected_nodes=False)(temp)

It looks like the transform is trying to make sure [the rest of this sentence was lost in extraction]
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 1 reply
-
Thanks for the issue. @kpstesla, can you take a look?
Beta Was this translation helpful? Give feedback.
-
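@jasperhyp In your examples, it looks like the issue is due to the fact that the `HeteroData` object does not know how many nodes of type `'y'` there are. The easiest way to fix this is to specify the number of `'y'` nodes explicitly. For example:

# Not OK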
temp = {('x', 'a', 'y'):{'edge_index':torch.tensor([[0], [0]])}, ('y', 'a_inv', 'x'):{'edge_index':torch.tensor([[0], [0]])},
('x', 'b', 'y'):{'edge_index':torch.tensor([[0, 0], [0, 1]])}, ('y', 'b_inv', 'x'):{'edge_index':torch.tensor([[0, 1], [0, 0]])}}
temp = HeteroData(temp)
temp['x'].x = torch.randn(1, 4)
metapaths = [[('x', 'a', 'y'), ('y', 'a_inv', 'x')],
[('x', 'a', 'y'), ('y', 'b_inv', 'x')]
]
AddMetaPaths(metapaths=metapaths, drop_orig_edges=True, drop_unconnected_nodes=False)(temp)
# OK
temp = {('x', 'a', 'y'):{'edge_index':torch.tensor([[0], [0]])}, ('y', 'a_inv', 'x'):{'edge_index':torch.tensor([[0], [0]])},
('x', 'b', 'y'):{'edge_index':torch.tensor([[0, 0], [0, 1]])}, ('y', 'b_inv', 'x'):{'edge_index':torch.tensor([[0, 1], [0, 0]])}}
temp = HeteroData(temp)
temp['x'].x = torch.randn(1, 4)
temp['y'].x = torch.randn(2, 4) # <======
metapaths = [[('x', 'a', 'y'), ('y', 'a_inv', 'x')],
[('x', 'a', 'y'), ('y', 'b_inv', 'x')]
]
AddMetaPaths(metapaths=metapaths, drop_orig_edges=True, drop_unconnected_nodes=False)(temp)
# Also OK
temp = {('x', 'a', 'y'):{'edge_index':torch.tensor([[0], [0]])}, ('y', 'a_inv', 'x'):{'edge_index':torch.tensor([[0], [0]])},
('x', 'b', 'y'):{'edge_index':torch.tensor([[0, 0], [0, 1]])}, ('y', 'b_inv', 'x'):{'edge_index':torch.tensor([[0, 1], [0, 0]])}}
temp = HeteroData(temp)
temp['x'].x = torch.randn(1, 4)
temp['y'].num_nodes = 2 # <======
metapaths = [[('x', 'a', 'y'), ('y', 'a_inv', 'x')],
[('x', 'a', 'y'), ('y', 'b_inv', 'x')]
]
AddMetaPaths(metapaths=metapaths, drop_orig_edges=True, drop_unconnected_nodes=False)(temp)

Does this resolve the issue?
Beta Was this translation helpful? Give feedback.
@jasperhyp In your examples, it looks like the issue is due to the fact that the `HeteroData` object does not know how many nodes of type `'y'` there are. So when you compute metapaths, the number of `'y'` nodes is inferred separately from each edge index involved. The edge types `'a'` and `'b_inv'` reference different numbers of `'y'` nodes. As a result, two different sizes are inferred for `'y'`, two incompatible adjacency matrices are constructed, and the sparse matrix multiplication fails with the assertion error. The easiest way to fix this is to specify the number of `'y'` nodes explicitly. You can do this either by setting the `.num_nodes` attribute, or by setting the `.x` attribute to a feature matrix of the correct shape. For example: