One thing you could try is to represent the node indices as a sparse matrix and use sparse-sparse matrix multiplication, e.g.:

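import torch
from torch_sparse import SparseTensor
from torch_scatter import scatter_max

# Transposed adjacency: row = target node, col = source node.
adj_t = SparseTensor(row=edge_index[0], col=edge_index[1], sparse_sizes=(N, N)).t()
# One-hot labels: initially every node carries its own index as its label.
x = SparseTensor.eye(N)

# out[i, j] counts how many neighbors of node i carry label j.
out = adj_t @ x

# Get the argmax of each row:
row, col, value = out.coo()

# `argmax` indexes into `value`, so it is mapped back to a column (label) below.
# Rows without any entries may get an out-of-range index and need separate handling.
max_value, argmax = scatter_max(value, row, dim=0, dim_size=N)

# Assign most common label:
x = SparseTensor(row=torch.arange(N), col=col[argmax], sparse_sizes=(N, N))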
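
For checking the sparse version on a small graph, it can help to have a plain-Python reference of the same idea: each node takes the most common label among the nodes pointing to it. This is only a sketch under a few assumptions that are not part of the answer above: `propagate_labels` is a hypothetical helper name, ties are broken by the smallest label, and nodes without incoming edges keep their current label.

```python
from collections import Counter


def propagate_labels(edge_index, labels):
    """One round of majority-vote label propagation (hypothetical helper).

    edge_index: pair of lists (source nodes, target nodes).
    labels: list with the current label of every node.
    """
    src, dst = edge_index
    # votes[t] counts the labels of all source nodes with an edge into t.
    votes = [Counter() for _ in labels]
    for s, t in zip(src, dst):
        votes[t][labels[s]] += 1

    new_labels = []
    for node, counter in enumerate(votes):
        if not counter:
            # Assumption: nodes without incoming edges keep their label.
            new_labels.append(labels[node])
        else:
            best = max(counter.values())
            # Assumption: ties are broken by taking the smallest label.
            new_labels.append(min(l for l, c in counter.items() if c == best))
    return new_labels


# Tiny example graph: edges 0->3, 1->3, 2->3, 3->0.
edge_index = ([0, 1, 2, 3], [3, 3, 3, 0])
labels = [0, 0, 1, 2]
print(propagate_labels(edge_index, labels))
```

In this example, node 3 receives the labels 0, 0 and 1 and therefore ends up with label 0, while node 0 receives only label 2 from node 3.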

Regarding your second question, this feels really similar to a classic link prediction problem, e.g.:

row, col = edge_index
out = torch.cat([x[row], x
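
The snippet above is cut off in the source, so the following is not a reconstruction of it. As a separate illustration of what a classic link prediction setup often looks like, here is a minimal sketch that scores each edge from the features of its two endpoints. It assumes `x` is a dense feature tensor rather than a `SparseTensor`; `EdgeScorer`, its layer sizes, and the toy data are hypothetical.

```python
import torch
from torch import nn


class EdgeScorer(nn.Module):
    """Hypothetical helper: scores an edge from its two endpoint features."""

    def __init__(self, in_channels: int, hidden_channels: int = 16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        row, col = edge_index
        # Concatenate source and target features for every edge.
        pair = torch.cat([x[row], x[col]], dim=-1)
        # One logit per edge.
        return self.mlp(pair).squeeze(-1)


torch.manual_seed(0)
x = torch.randn(4, 8)                              # 4 nodes, 8 features each
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 3 candidate edges
scorer = EdgeScorer(in_channels=8)
logits = scorer(x, edge_index)                     # shape: [3]
```

The logits would typically be trained with a binary cross-entropy loss against existing (positive) and sampled non-existing (negative) edges.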

Answer selected by cameronosmith