pytorch_geometric/torch_geometric/data/graph_store.py Lines 277 to 281 in fb1d855

TL;DR: The CSR format that this code converts to does not produce a rowptr tensor long enough to represent graphs with isolated nodes whose indices are higher than every source node in the edge index. As a result, the random_walk implementations in both pyg_lib and torch_cluster fail. Is there a way to get the CSR representation without losing these nodes?

More explanation: I have been getting a fun runtime error in a node2vec implementation that I am writing (I have my reasons, but they aren't relevant to this story). It was a CUDA runtime error saying that I was going out of bounds and accessing memory I should not be. As far as I could tell, however, I was not going out of bounds. Then I looked deeper, and it turns out that, given a graph like the following, the CSR representation is

    rowptr=tensor([0, 1, 1, 2, 3], device='cuda:0')
    col=tensor([3, 1, 2], device='cuda:0')

This means we are losing information about the graph, because the CSR representation that includes nodes 4 and 5 should look like

    rowptr=tensor([0, 1, 1, 2, 3, 3, 3], device='cuda:0')
    col=tensor([3, 1, 2], device='cuda:0')

This feels like it should be raised as an issue if it really is a problem, but maybe I'm missing something, so I want to ask first.
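To make the length mismatch concrete, here is a minimal sketch in plain Python lists (no PyG or CUDA involved). The helper names `build_rowptr` and `neighbors` are hypothetical and only illustrate the idea; the `IndexError` stands in for the out-of-bounds read that a random walk kernel would hit on the GPU.

```python
def build_rowptr(rows, num_nodes):
    """Build a CSR row pointer from source indices (hypothetical helper)."""
    counts = [0] * num_nodes
    for r in rows:
        counts[r] += 1
    rowptr = [0]
    for c in counts:
        # rowptr[i + 1] - rowptr[i] is the out-degree of node i.
        rowptr.append(rowptr[-1] + c)
    return rowptr


def neighbors(rowptr, col, node):
    """Look up the neighbors of `node`; needs rowptr[node + 1] to exist."""
    start, end = rowptr[node], rowptr[node + 1]
    return col[start:end]


rows = [0, 2, 3]  # source nodes, already sorted
col = [3, 1, 2]   # target nodes

# Inferring the node count from the sources drops nodes 4 and 5.
short_rowptr = build_rowptr(rows, max(rows) + 1)
# Passing the true node count keeps one entry per node, plus one.
full_rowptr = build_rowptr(rows, 6)

print(short_rowptr)
print(full_rowptr)
print(neighbors(full_rowptr, col, 4))  # isolated node: empty neighbor list

try:
    neighbors(short_rowptr, col, 4)
except IndexError:
    print("node 4 is out of bounds for the short rowptr")
```

With the full row pointer, an isolated node simply has an empty neighbor range; with the short one, the lookup reads past the end of the array.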
Replies: 1 comment 15 replies
You have isolated nodes in your graph that are not included in the CSR representation, which is what causes the issue in your node2vec implementation. To preserve these isolated nodes, you can pass the total number of nodes explicitly when converting from COO to CSR. You could try a function like this:

    import torch
    from torch_geometric.utils import index_sort

    def coo_to_csr(row, col, num_nodes=None):
        if num_nodes is None:
            # Fall back to the largest index seen as a source or a target.
            num_nodes = int(max(row.max(), col.max())) + 1
        # Sort edges by source node so each node's neighbors are contiguous.
        row, perm = index_sort(row, max_value=num_nodes)
        col = col[perm]
        # rowptr[i + 1] - rowptr[i] is the out-degree of node i.
        rowptr = torch.zeros(num_nodes + 1, dtype=row.dtype, device=row.device)
        rowptr[1:] = torch.cumsum(torch.bincount(row, minlength=num_nodes), dim=0)
        return rowptr, col

This is how you can call it and assign the result:

    row = torch.tensor([0, 2, 3], dtype=torch.long)
    col = torch.tensor([3, 1, 2], dtype=torch.long)
    num_nodes = 6  # Optional; set it to include the isolated nodes
    rowptr, col = coo_to_csr(row, col, num_nodes)
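If the CSR tensors have already been produced by another code path, one possible workaround is to pad the row pointer instead of rebuilding it. The sketch below assumes the only problem is missing trailing entries for isolated nodes; `pad_rowptr` is a hypothetical helper, not part of PyG, pyg_lib, or torch_cluster.

```python
import torch


def pad_rowptr(rowptr: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Extend a too-short CSR row pointer to num_nodes + 1 entries (hypothetical helper)."""
    missing = num_nodes + 1 - rowptr.numel()
    if missing < 0:
        raise ValueError("rowptr already describes more than num_nodes nodes")
    if missing == 0:
        return rowptr
    # Trailing isolated nodes have no edges, so each one repeats the final offset.
    tail = rowptr[-1:].expand(missing)
    return torch.cat([rowptr, tail])


rowptr = torch.tensor([0, 1, 1, 2, 3])
padded = pad_rowptr(rowptr, num_nodes=6)
print(padded.tolist())
```

Repeating the last offset gives every missing node an empty neighbor range, which matches the expected row pointer from the question.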
Thanks for the example, helped a lot. I fixed it in #7316. Sorry for any inconvenience.