How to use global_mean_pool for single graph dataset like Cora without batch? #4666
Hi, I would like to use the `global_mean_pool` function to get a single graph's embedding, but I got an error when calling `global_mean_pool(x)`. If I set the parameter `batch=None`, I get a different error. How should I use `global_mean_pool` correctly to obtain the graph embedding from the node embeddings of a single graph? BTW, my pytorch_geometric version is 1.6.1. Thanks, and I look forward to your reply.
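You can either do:

```python
# Single graph: average over all nodes; keepdim=True keeps the [1, num_features] shape
x = x.mean(dim=0, keepdim=True) if batch is None else global_mean_pool(x, batch)
```

or do

```python
import torch
from torch_geometric.nn import global_mean_pool

if batch is None:  # single graph: assign every node to graph 0
    batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
x = global_mean_pool(x, batch)
```

Note that this is already fixed in later PyG versions, where `batch` may be `None`.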
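To see why both options give the same result, here is a minimal sketch of mean pooling with a batch vector in plain PyTorch. `mean_pool` is a hypothetical helper written for illustration, not PyG's implementation; it assumes `x` is a 2-D float tensor of node embeddings and, like the second option, treats `batch=None` as "every node belongs to graph 0".

```python
from typing import Optional

import torch


def mean_pool(x: torch.Tensor, batch: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Average node features per graph (hypothetical helper, not PyG's API).

    x:     [num_nodes, num_features] node embeddings
    batch: [num_nodes] graph index of each node, or None for a single graph
    """
    if batch is None:
        # Single graph (e.g. Cora): every node belongs to graph 0.
        batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
    num_graphs = int(batch.max()) + 1
    # Sum the features of the nodes belonging to each graph.
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    sums.index_add_(0, batch, x)
    # Divide each sum by that graph's node count (clamp guards against empty graphs).
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return sums / counts.unsqueeze(1).to(x.dtype)


x = torch.randn(5, 3)
g = mean_pool(x)  # one embedding for the whole graph, shape [1, 3]
assert torch.allclose(g, x.mean(dim=0, keepdim=True))
```

With an all-zero batch vector there is only one group, so the per-graph mean reduces to `x.mean(dim=0, keepdim=True)`; that is why the one-liner and the zero-batch workaround agree.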