Generating fully-connected edge index for a Batch object efficiently #4285
Though I know it's inefficient to handle fully-connected graphs in PyG, I do need the fully-connected edges. I know how to obtain fully-connected edges for a single graph. However, I don't want to pre-compute the edges and store them on disk or in RAM during preprocessing. Instead, I want to compute the edges after a Batch object has been moved to the GPU, so I'm looking for an efficient way to generate the fully-connected edges in a batched manner. I saw a way to implement it in this pull request, but it involves a for loop and I'm not sure whether it is efficient enough.
Since the node features in my case are coordinates, one workaround may be to use radius_graph from torch_cluster with the radius set to a very large number, but that involves unnecessary computation (distances between node pairs that would be kept anyway). Do I need to write a custom CUDA op to efficiently generate fully-connected edges for a Batch object?
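For reference, here is a plain-PyTorch sketch of the two building blocks involved. The first is the complete edge index for a single graph with `n` nodes. The second is the relationship between a batch vector and a `ptr` offset vector. It assumes `batch` is sorted, meaning all nodes of graph 0 come first, then graph 1, and so on; this is how PyG's `Batch` lays out nodes. The helper names `single_graph_full_edges` and `ptr_from_batch` are hypothetical.

```python
import torch
from torch import Tensor


def single_graph_full_edges(n: int, device=None) -> Tensor:
    """All ordered pairs (i, j) for one graph with n nodes, self-loops included.

    Returns an edge index of shape [2, n * n].
    """
    idx = torch.arange(n, device=device)
    row = idx.repeat_interleave(n)  # 0,...,0, 1,...,1, ...
    col = idx.repeat(n)             # 0,...,n-1, 0,...,n-1, ...
    return torch.stack([row, col], dim=0)


def ptr_from_batch(batch: Tensor, num_graphs: int) -> Tensor:
    """Offsets such that graph g owns nodes ptr[g]:ptr[g + 1] (assumes sorted batch)."""
    counts = torch.bincount(batch, minlength=num_graphs)  # nodes per graph
    return torch.cat([counts.new_zeros(1), counts.cumsum(0)])


print(single_graph_full_edges(2))
# tensor([[0, 0, 1, 1],
#         [0, 1, 0, 1]])
print(ptr_from_batch(torch.tensor([0, 0, 1, 1, 1]), num_graphs=2))
# tensor([0, 2, 5])
```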
Here are two ways I can think of.
# Option 1: mask over all N^2 node pairs (no Python loop, but O(N^2) memory)
import torch
from torch import Tensor

def fully_connect(batch: Tensor) -> Tensor:
    # batch: tensor of shape [N], where N is the number of nodes in the batch (N can get very large).
    n = batch.size(0)
    idx = torch.arange(n, device=batch.device)  # same device as batch (e.g. GPU)
    A = torch.cartesian_prod(idx, idx)  # all ordered pairs, shape [N^2, 2]
    # This is the key: mask has shape [N^2], and mask[i] is True iff
    # batch[A[i, 0]] == batch[A[i, 1]], i.e. both nodes belong to the same graph.
    mask = batch[A[:, 0]] == batch[A[:, 1]]
    A = A[mask]  # shape [E, 2]
    return A.t().contiguous()  # edge index of shape [2, E], as PyG expects
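To see why the mask works, here is the first approach traced on a toy batch of two graphs. Nodes 0 and 1 belong to graph 0, and node 2 belongs to graph 1. Only pairs whose endpoints carry the same batch label survive, which gives exactly the block-diagonal pattern of per-graph complete graphs. The toy `batch` tensor is an assumption made for illustration.

```python
import torch

batch = torch.tensor([0, 0, 1])  # hypothetical: 3 nodes in 2 graphs
n = batch.size(0)

pairs = torch.cartesian_prod(torch.arange(n), torch.arange(n))  # [9, 2]
mask = batch[pairs[:, 0]] == batch[pairs[:, 1]]                  # [9], bool

print(mask.tolist())
# [True, True, False, True, True, False, False, False, True]
edge_index = pairs[mask].t()
print(edge_index)
# tensor([[0, 0, 1, 1, 2],
#         [0, 1, 0, 1, 2]])
```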
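# Option 2: one cartesian product per graph, using the batch's ptr offsets
import torch
from torch import Tensor

def fully_connect(ptr: Tensor) -> Tensor:
    """ptr has shape [batch_size + 1]; graph i owns nodes ptr[i]:ptr[i + 1]."""
    bounds = ptr.tolist()  # copy offsets to the host once instead of once per iteration
    A = torch.cat([
        torch.cartesian_prod(torch.arange(start, end, device=ptr.device),
                             torch.arange(start, end, device=ptr.device))
        for start, end in zip(bounds[:-1], bounds[1:])
    ], dim=0)
    return A.t().contiguous()  # edge index of shape [2, E]

# usage
assert data.ptr is not None
edge_index = fully_connect(data.ptr)

Thanks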
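For what it's worth, the per-graph Python loop in the second approach can be removed with standard tensor ops, so a custom CUDA kernel is likely not necessary for this. The sketch below uses a hypothetical function, `fully_connect_vectorized`. It assumes `ptr` is a `long` tensor of node offsets, as stored on a PyG `Batch`. It should produce the same edges, in the same order, as the loop version. It allocates only O(E) memory rather than the O(N^2) candidate pairs of the first approach. One caveat: `repeat_interleave` with tensor-valued repeats must know its output length, so on the GPU it may still cause one host-device synchronization.

```python
import torch
from torch import Tensor


def fully_connect_vectorized(ptr: Tensor) -> Tensor:
    """All ordered node pairs inside each graph (self-loops included), no Python loop.

    ptr: shape [num_graphs + 1]; graph g owns nodes ptr[g]:ptr[g + 1].
    Returns an edge index of shape [2, sum_g n_g ** 2].
    """
    counts = ptr[1:] - ptr[:-1]        # n_g, nodes per graph
    num_edges = counts * counts        # n_g ** 2, edges per graph
    graph_ids = torch.arange(counts.numel(), device=ptr.device)

    # Graph id of every output edge: g repeated n_g ** 2 times.
    edge_graph = torch.repeat_interleave(graph_ids, num_edges)

    # Index of each graph's first edge, then each edge's position k within its graph.
    edge_start = torch.cumsum(num_edges, dim=0) - num_edges
    k = torch.arange(edge_graph.numel(), device=ptr.device) - edge_start[edge_graph]

    # Decode k into the local pair (k // n_g, k % n_g), then shift by the node offset.
    n = counts[edge_graph]
    offset = ptr[:-1][edge_graph]
    row = torch.div(k, n, rounding_mode="floor") + offset
    col = k % n + offset
    return torch.stack([row, col], dim=0)


print(fully_connect_vectorized(torch.tensor([0, 2, 3])))
# tensor([[0, 0, 1, 1, 2],
#         [0, 1, 0, 1, 2]])
```

Graphs with zero nodes are handled naturally, because they contribute zero repeats and therefore no edges.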
How about you obtain the fully-connected graphs as a transform on the fly for each individual data object? This way, you do not need to think about how to create a batch-wise fully-connected graph. To me, it seems tricky to implement without a for-loop.
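Here is a minimal sketch of that suggestion. It assumes the dataset accepts any callable as its `transform`; PyG datasets take a callable that is applied to each data object on access. The class name `FullyConnect` is hypothetical. The transform runs per graph before collation, and PyG's batching offsets each graph's `edge_index` by the number of preceding nodes. The batched result is therefore already block-diagonal. The trade-off is that the edges are built on the CPU in the data-loading step rather than on the GPU, although nothing is stored on disk.

```python
import torch


class FullyConnect:
    """Hypothetical per-graph transform: replaces data.edge_index with all
    ordered node pairs (self-loops included). Assumes `data` exposes
    `num_nodes` and accepts an `edge_index` attribute, as PyG's Data does."""

    def __call__(self, data):
        idx = torch.arange(data.num_nodes)
        # cartesian_prod gives [n * n, 2]; transpose to PyG's [2, E] layout.
        data.edge_index = torch.cartesian_prod(idx, idx).t().contiguous()
        return data
```

Usage would look like `dataset = MyDataset(root, transform=FullyConnect())`, where `MyDataset` stands in for your own dataset class.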