
In general, you would first apply a CNN to your images, and then use the embedding produced by the CNN as input to your GNN. Given images of shape [num_nodes, num_channels, height, width] (the layout PyTorch convolution layers expect), you can do:

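import torch

img = CNN(img)                                # CNN: your image model, applied to every node's image
img = img.view(num_nodes, -1)                 # flatten to [num_nodes, embedding_dim]
x = torch.cat([img, feature_vector], dim=-1)  # append the existing per-node features
out = GNN(x, edge_index)                      # GNN: your graph model over the combined features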
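A fuller, runnable version of the joint setup might look like the sketch below. It uses only plain PyTorch so it stays self-contained: `SimpleCNN`, `MeanAggGNNLayer` and `ImageGNN` are hypothetical names, and the mean-over-neighbours layer is a simple stand-in for a real PyTorch Geometric layer such as `GCNConv`. All sizes and the toy graph are made up for illustration.

```python
import torch
import torch.nn as nn


class SimpleCNN(nn.Module):
    """Hypothetical small CNN mapping each node's image to a flat embedding."""

    def __init__(self, in_channels: int, out_dim: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),  # fixed 4x4 spatial output for any image size
        )
        self.fc = nn.Linear(16 * 4 * 4, out_dim)

    def forward(self, img):
        h = self.conv(img)                       # [num_nodes, 16, 4, 4]
        return self.fc(h.flatten(start_dim=1))   # [num_nodes, out_dim]


class MeanAggGNNLayer(nn.Module):
    """Hypothetical GCN-like layer: own features plus mean of incoming neighbours."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin_self = nn.Linear(in_dim, out_dim)
        self.lin_neigh = nn.Linear(in_dim, out_dim)

    def forward(self, x, edge_index):
        src, dst = edge_index  # messages flow from src to dst
        agg = torch.zeros_like(x).index_add(0, dst, x[src])  # sum of neighbour features
        deg = torch.zeros(x.size(0), dtype=x.dtype, device=x.device).index_add(
            0, dst, torch.ones(dst.numel(), dtype=x.dtype, device=x.device))
        agg = agg / deg.clamp(min=1).unsqueeze(-1)  # sum -> mean (isolated nodes stay 0)
        return self.lin_self(x) + self.lin_neigh(agg)


class ImageGNN(nn.Module):
    """CNN and GNN in one module, so both are trained jointly."""

    def __init__(self, in_channels, img_dim, feat_dim, hidden, num_classes):
        super().__init__()
        self.cnn = SimpleCNN(in_channels, img_dim)
        self.gnn1 = MeanAggGNNLayer(img_dim + feat_dim, hidden)
        self.gnn2 = MeanAggGNNLayer(hidden, num_classes)

    def forward(self, img, feature_vector, edge_index):
        img_emb = self.cnn(img)                           # [N, img_dim]
        x = torch.cat([img_emb, feature_vector], dim=-1)  # [N, img_dim + feat_dim]
        x = torch.relu(self.gnn1(x, edge_index))
        return self.gnn2(x, edge_index)                   # [N, num_classes]


torch.manual_seed(0)
num_nodes = 5
img = torch.randn(num_nodes, 3, 32, 32)         # toy images, one per node
feature_vector = torch.randn(num_nodes, 8)      # toy per-node features
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])
labels = torch.tensor([0, 1, 0, 1, 1])

model = ImageGNN(in_channels=3, img_dim=32, feat_dim=8, hidden=16, num_classes=2)
out = model(img, feature_vector, edge_index)
print(out.shape)  # torch.Size([5, 2])

loss = nn.functional.cross_entropy(out, labels)
loss.backward()
# The CNN is part of the same autograd graph, so its weights get gradients too.
print(model.cnn.fc.weight.grad is not None)  # True
```

Because the CNN lives inside the model, every forward pass pushes all node images through it and keeps their activations for backpropagation, which is where the memory cost comes from.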

Keep in mind that this will not scale well to large graphs, since the CNN and GNN are trained jointly: the images of all nodes are passed through the CNN together in every forward pass. An alternative is to use a pre-trained CNN, compute the node embeddings once, and afterwards use them as detached (fixed) inputs to your GNN.
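The precompute-then-detach alternative could be sketched as follows. The "pre-trained" CNN here is only a placeholder with random weights (in practice you might load a torchvision backbone instead); `precompute_embeddings`, `MeanAggGNN`, the batch size and the toy data are all hypothetical. The embeddings are computed once under `torch.no_grad()` in mini-batches, so only the GNN is trained afterwards.

```python
import torch
import torch.nn as nn


def precompute_embeddings(cnn: nn.Module, img: torch.Tensor, batch_size: int = 256) -> torch.Tensor:
    """Run a frozen CNN over all node images once, in mini-batches, without autograd."""
    cnn.eval()
    chunks = []
    with torch.no_grad():  # no activations are stored, so memory stays bounded per batch
        for start in range(0, img.size(0), batch_size):
            chunks.append(cnn(img[start:start + batch_size]).flatten(start_dim=1))
    return torch.cat(chunks, dim=0)  # [num_nodes, D], carries no gradient history


class MeanAggGNN(nn.Module):
    """Hypothetical one-layer GNN: own features plus mean of incoming neighbours."""

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        self.lin_self = nn.Linear(in_dim, num_classes)
        self.lin_neigh = nn.Linear(in_dim, num_classes)

    def forward(self, x, edge_index):
        src, dst = edge_index
        agg = torch.zeros_like(x).index_add(0, dst, x[src])
        deg = torch.zeros(x.size(0), dtype=x.dtype).index_add(
            0, dst, torch.ones(dst.numel(), dtype=x.dtype))
        return self.lin_self(x) + self.lin_neigh(agg / deg.clamp(min=1).unsqueeze(-1))


torch.manual_seed(0)
num_nodes = 6
img = torch.randn(num_nodes, 3, 32, 32)
feature_vector = torch.randn(num_nodes, 4)
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])
y = torch.tensor([0, 1, 0, 1, 0, 1])

# Placeholder for a pre-trained CNN; random weights, for illustration only.
cnn = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
img_emb = precompute_embeddings(cnn, img, batch_size=4)  # [6, 16], computed once
x = torch.cat([img_emb, feature_vector], dim=-1)         # [6, 20], fixed GNN input

gnn = MeanAggGNN(x.size(1), num_classes=2)
opt = torch.optim.Adam(gnn.parameters(), lr=0.01)  # only GNN parameters are optimized
for epoch in range(50):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(gnn(x, edge_index), y)
    loss.backward()
    opt.step()
```

The trade-off is that the image features can no longer adapt to the graph task, in exchange for paying the CNN cost only once and keeping GNN training cheap.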

Answer selected by srg9000
Category: Q&A
Labels: none
2 participants