Not sure the above code will work correctly for batch_size > 1. The output of the line global_x = global_x.repeat(x.shape[0]).view(x.shape[0], global_x.shape[0]) should have a shape of num_nodes x (batch_size * num_feats), so each node's features would be concatenated with the pooled features of every graph in the batch, not just those of its own graph.

So you could check whether the code below works for you instead:

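  import torch
  from torch_geometric.nn import global_max_pool
  global_x = global_max_pool(x, data.batch)[data.batch]  # [num_nodes, num_feats]: pooled features of each node's own graph
  x = torch.cat((x, global_x), dim=1)  # [num_nodes, 2 * num_feats]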
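A minimal, self-contained sketch of why the indexing trick works, assuming only PyTorch. global_max_pool_plain is a hypothetical loop-based stand-in for torch_geometric.nn.global_max_pool so that the example runs without PyG, and the toy x and batch tensors are made up for illustration. It also contrasts the result with naively giving every node the pooled features of all graphs, which is the shape problem described above.

```python
import torch


def global_max_pool_plain(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for torch_geometric.nn.global_max_pool:
    # row g holds the feature-wise max over the nodes whose batch id is g.
    num_graphs = int(batch.max()) + 1
    return torch.stack([x[batch == g].max(dim=0).values for g in range(num_graphs)])


# Toy mini-batch of 3 graphs: nodes 0-1 -> graph 0, nodes 2-3 -> graph 1, node 4 -> graph 2.
x = torch.tensor([[1.0, -2.0],
                  [3.0, -5.0],
                  [0.0, 4.0],
                  [2.0, 1.0],
                  [-1.0, 6.0]])
batch = torch.tensor([0, 0, 1, 1, 2])

pooled = global_max_pool_plain(x, batch)  # [num_graphs, num_feats] = [3, 2]

# Indexing with batch copies each graph's pooled row to every node of that graph.
global_x = pooled[batch]                 # [num_nodes, num_feats] = [5, 2]
out = torch.cat((x, global_x), dim=1)    # [num_nodes, 2 * num_feats] = [5, 4]

# For contrast: giving every node the pooled features of *all* graphs.
naive_global_x = pooled.reshape(1, -1).expand(x.size(0), -1)  # [5, 3 * 2] = [5, 6]

print(out)
print(naive_global_x.shape)
```

The indexing step is what keeps batch_size > 1 correct in this sketch: it gathers exactly one pooled row per node, so the concatenated width stays 2 * num_feats no matter how many graphs are in the batch, whereas the naive version grows with the number of graphs.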

1 reply (@scarpma)
Answer selected by rusty1s
Category: Q&A