How to use a virtual node with PyG? How many methods are there? #9373
Unanswered
StefanIsSmart asked this question in Q&A
Replies: 1 comment
Hi @StefanIsSmart, below is an example of adding a virtual node to a molecule graph:

```python
import torch
from torch_geometric.data import Data
from torch_geometric.utils import from_smiles  # requires RDKit

if __name__ == '__main__':
    data: Data = from_smiles(smiles='F/C=C/F')

    # Create molecule (graph) level features of the same dim
    # as the node-level features (random placeholder values here)
    mol_features = torch.randint(0, 9, size=(1, data.x.shape[-1]))

    # Add an extra node (aka virtual node) containing the molecule-level features
    data.x = torch.cat((data.x, mol_features))
    n = data.num_nodes - 1  # Number of real nodes == index of the virtual node

    # Undirected graph
    # All nodes -> virtual node
    d1 = torch.tensor([list(range(n)), [n] * n])
    # Virtual node -> all nodes
    d2 = torch.flip(d1, dims=[0])
    # Append to the `edge_index`
    data.edge_index = torch.cat((data.edge_index, d1, d2), dim=-1)

    # Add all-zero edge attributes for the newly added edges
    num_old_edges, edge_dim = data.edge_attr.shape
    num_missing = data.edge_index.shape[1] - num_old_edges
    edge_attr = data.edge_attr.new_zeros((num_missing, edge_dim))
    data.edge_attr = torch.cat((data.edge_attr, edge_attr))

    print(data.x)
    print(data.edge_index)
    print(data.edge_attr)

    # Sort the edges by (source, target), keeping edge_attr aligned
    perm = (data.edge_index[0] * data.x.size(0) + data.edge_index[1]).argsort()
    data.edge_index, data.edge_attr = data.edge_index[:, perm], data.edge_attr[perm]
    print(data.edge_index)
    print(data.edge_attr)
```

You can now use any GNN model and perform message passing on the augmented graph.
Hello everyone, I have a question:
My graphs are fully connected, and they have `node_features` and `edge_features` (not `None`). Now I want to add `graph_level_features`, just like a virtual node. How can I do that with PyG?
I want to update `node_features` using `graph_level_features`, and update `graph_level_features` using `node_features`. The difference is that the `edge_features` between `graph_level_features` and `node_features` are `None`. How can I do that?
Which method is more elegant and effective?
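Another option, which avoids inventing edge features for the virtual-node edges, is to keep the graph-level features in their own tensor and exchange information with the nodes inside the model; virtual-node GNN variants are often implemented roughly this way. Below is a minimal sketch in plain PyTorch. `VirtualNodeLayer` is a hypothetical name, and the MLP sizes and mean aggregation are arbitrary choices. It assumes `batch` maps every node to its graph index (as a PyG `Batch` from the `DataLoader` provides) and that graph-level features are stored with one row per graph, e.g. a `[1, F]` attribute on each `Data`, which PyG concatenates to `[num_graphs, F]` when batching. In PyG, the manual mean below could be replaced with `torch_geometric.nn.global_mean_pool(x, batch)`.

```python
import torch
from torch import nn


class VirtualNodeLayer(nn.Module):
    """Hypothetical layer: one round of node <-> graph-level exchange.

    Graph-level features `u` live in their own tensor, so no edge
    features are needed between the "virtual node" and the real nodes.
    """

    def __init__(self, node_dim: int, graph_dim: int):
        super().__init__()
        # [node features | graph features of its graph] -> new node features
        self.node_mlp = nn.Sequential(nn.Linear(node_dim + graph_dim, node_dim), nn.ReLU())
        # [mean of node features | old graph features] -> new graph features
        self.graph_mlp = nn.Sequential(nn.Linear(node_dim + graph_dim, graph_dim), nn.ReLU())

    def forward(self, x, u, batch):
        # x: [num_nodes, node_dim], u: [num_graphs, graph_dim], batch: [num_nodes]
        # Graph -> nodes: every node reads the features of its own graph
        x = self.node_mlp(torch.cat([x, u[batch]], dim=-1))

        # Nodes -> graph: mean-pool the (updated) node features per graph
        num_graphs = u.size(0)
        summed = x.new_zeros(num_graphs, x.size(-1)).index_add_(0, batch, x)
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).unsqueeze(-1)
        pooled = summed / counts.to(x.dtype)
        u = self.graph_mlp(torch.cat([pooled, u], dim=-1))
        return x, u


# Usage: two graphs in one batch (3 nodes + 2 nodes)
torch.manual_seed(0)
layer = VirtualNodeLayer(node_dim=8, graph_dim=4)
x = torch.randn(5, 8)
u = torch.randn(2, 4)
batch = torch.tensor([0, 0, 0, 1, 1])
x_new, u_new = layer(x, u, batch)
print(x_new.shape, u_new.shape)
```

Between these layers you would still run your usual edge-aware convolution on the real nodes (for example PyG's `GINEConv`, which accepts `edge_attr`), so the existing edge features are used only where they exist. Updating the nodes first and pooling the updated nodes is a design choice; pooling first is equally valid. Which approach works better than the explicit extra node is an empirical question for your data.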