Handle shape of node representations #5751
Hi. I am using GNNs to get node representations to feed to downstream models. I have the following model:
I want my network to output a representation of each node, assuming there can be 50265 of them (num_nodes). Specifically, I want a way to produce a representation of shape
Replies: 1 comment
You want to drop the global_add_pool here. In addition, your final Linear module should output hidden_size features instead of 50265. Then, you can do:

    x = self.lin(x)
    x, mask = to_dense_batch(x, batch)
    return x  # [batch_size, num_nodes, num_features]
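For readers unfamiliar with dense batching, the following plain-Python sketch illustrates the padding that a call like to_dense_batch(x, batch) performs. It is not the PyTorch Geometric implementation: dense_batch is a hypothetical helper written only for illustration, and it works on nested lists rather than tensors. It assumes batch assigns each node to a graph index starting at 0 and that every graph has at least one node.

```python
def dense_batch(x, batch, fill_value=0.0):
    """Pad per-node feature rows into [batch_size, max_nodes, num_features].

    x:     list of feature rows, one per node, all graphs concatenated
    batch: list of graph indices; batch[i] is the graph that node i belongs to
    (hypothetical helper for illustration, not a library function)
    """
    batch_size = max(batch) + 1
    num_features = len(x[0])

    # Group node rows by the graph they belong to, preserving node order.
    per_graph = [[] for _ in range(batch_size)]
    for row, graph_idx in zip(x, batch):
        per_graph[graph_idx].append(list(row))

    # Every graph is padded up to the size of the largest graph in the batch.
    max_nodes = max(len(rows) for rows in per_graph)

    dense, mask = [], []
    for rows in per_graph:
        pad = max_nodes - len(rows)
        dense.append(rows + [[fill_value] * num_features for _ in range(pad)])
        # True marks a real node, False marks a padded slot.
        mask.append([True] * len(rows) + [False] * pad)
    return dense, mask


# Two graphs: graph 0 has 3 nodes, graph 1 has 1 node, 2 features per node.
x = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
batch = [0, 0, 0, 1]
dense, mask = dense_batch(x, batch)
print(len(dense), len(dense[0]), len(dense[0][0]))  # 2 3 2
print(mask)  # [[True, True, True], [True, False, False]]
```

Note that the middle dimension here is the node count of the largest graph in the batch, not a fixed 50265. In PyTorch Geometric, to_dense_batch accepts a max_num_nodes argument if a fixed size is needed, and the returned mask tells downstream models which slots hold real nodes.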