Hi, I use PyG to implement my graph pooling model; however, PyG seems to suffer from numerical instability on GPUs when using … When I implement a graph pooling layer, I need some global pooling operations like …
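You can use `segment_csr` from `torch-scatter`:

```python
from torch_scatter import segment_csr

# batch.ptr holds the offsets of each graph's nodes in x (nodes sorted by graph)
global_x = segment_csr(x, batch.ptr, reduce='sum')
```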
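To make the `batch.ptr` argument concrete, here is a minimal pure-Python sketch of what a CSR segment reduction computes: output row `i` reduces the node rows `x[ptr[i]:ptr[i+1]]`, i.e. all nodes of graph `i`. The helper `segment_reduce_csr` is hypothetical and only illustrates the semantics; it is not the `torch-scatter` implementation. It assumes nodes are sorted by graph (as they are in a PyG mini-batch) and, as an assumption of this sketch, returns zeros for empty segments. Because each segment is reduced over a contiguous slice in a fixed order rather than through concurrent atomic adds, repeated runs produce identical results.

```python
from typing import List, Sequence


def segment_reduce_csr(
    x: Sequence[Sequence[float]], ptr: Sequence[int], reduce: str = "sum"
) -> List[List[float]]:
    """Hypothetical reference: reduce rows x[ptr[i]:ptr[i+1]] into one row per graph."""
    if reduce not in ("sum", "mean", "max"):
        raise ValueError(f"unsupported reduce: {reduce}")
    num_features = len(x[0]) if x else 0
    out: List[List[float]] = []
    for start, end in zip(ptr[:-1], ptr[1:]):
        rows = x[start:end]
        if not rows:
            # Empty graph: fill with zeros (an assumption of this sketch).
            out.append([0.0] * num_features)
            continue
        if reduce == "max":
            out.append([float(max(col)) for col in zip(*rows)])
        else:
            # Fixed left-to-right summation order: same result on every run.
            sums = [float(sum(col)) for col in zip(*rows)]
            if reduce == "mean":
                sums = [s / len(rows) for s in sums]
            out.append(sums)
    return out


# Two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 20.0], [30.0, 40.0]]
ptr = [0, 3, 5]
print(segment_reduce_csr(x, ptr, "sum"))   # [[9.0, 12.0], [40.0, 60.0]]
print(segment_reduce_csr(x, ptr, "mean"))  # [[3.0, 4.0], [20.0, 30.0]]
```

The same `ptr` layout is what lets `reduce='mean'` or `reduce='max'` serve as global mean or max pooling without an explicit per-node `batch` index.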