I have a GNN model with a global pooling layer, so I can compute both node representations and graph representations. How can I combine the two for a batch of graphs?
Answered by rusty1s, May 2, 2022
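You can consider concatenating both as part of your model:

```python
import torch
from torch_geometric.nn import global_mean_pool
x_global = global_mean_pool(x, batch)  # [num_graphs, F]: one embedding per graph
x = torch.cat([x, x_global[batch]], dim=-1)  # [num_nodes, 2F]: each node also gets its graph's embedding
```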
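Here is why `x_global[batch]` works. The pooled tensor has one row per graph, and `batch` maps every node to the index of its graph. Indexing the pooled tensor with `batch` therefore copies each graph's embedding onto every node of that graph. Below is a minimal sketch in plain PyTorch. `mean_pool` is a hypothetical stand-in for `global_mean_pool` so the snippet runs without PyTorch Geometric installed. It assumes graph indices in `batch` run from 0 to B-1. The toy feature values are made up for illustration.

```python
import torch


def mean_pool(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Average node features per graph (hypothetical stand-in for
    torch_geometric.nn.global_mean_pool; assumes batch holds 0..B-1)."""
    num_graphs = int(batch.max()) + 1
    sums = torch.zeros(num_graphs, x.size(-1), dtype=x.dtype)
    sums.index_add_(0, batch, x)  # add each node's features into its graph's row
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts.unsqueeze(-1).to(x.dtype)  # divide sums by nodes per graph


# Toy batch: graph 0 owns nodes 0-1, graph 1 owns nodes 2-4.
x = torch.tensor([[1.0], [3.0], [2.0], [4.0], [6.0]])
batch = torch.tensor([0, 0, 1, 1, 1])

x_global = mean_pool(x, batch)  # shape [2, 1]: one row per graph
x_combined = torch.cat([x, x_global[batch]], dim=-1)  # shape [5, 2]
print(x_combined)
```

Each node row now holds its own feature followed by its graph's mean. For example, node 0 becomes `[1., 2.]` and node 4 becomes `[6., 4.]`. With mean pooling the feature width doubles, so any layer that consumes the concatenated tensor needs an input size of `2 * F`.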
Answer selected by Bowen-n