@@ -58,35 +58,45 @@ def get_aspirin_graph(self):
5858
5959 # --- Edge list (bidirectional) ---
6060 # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26)
61- # 2 directed edges of one undirected edge are adjacent to each other --- this is needed
62-
63- # fmt: off
6461 # Generated using RDKIT 2024.9.6
65- edge_index = torch .tensor ([
66- [0 , 1 , 1 , 2 , 1 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 , 8 , 8 , 9 , 9 , 10 , 10 , 11 , 10 , 12 , 9 , 4 ], # Start atoms (u)
67- [1 , 0 , 2 , 1 , 3 , 1 , 4 , 3 , 5 , 4 , 6 , 5 , 7 , 6 , 8 , 7 , 9 , 8 , 10 , 9 , 11 , 10 , 12 , 10 , 4 , 9 ] # End atoms (v)
62+ # fmt: off
63+ _edge_index = torch .tensor ([
64+ [0 , 1 , 1 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 10 , 9 ], # Start atoms (u)
65+ [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 4 ] # End atoms (v)
6866 ], dtype = torch .long )
6967 # fmt: on
7068
69+ # Reverse the edges
70+ reversed_edge_index = _edge_index [[1 , 0 ], :]
71+
72+ # First all directed edges from source to target are placed,
73+ # then all directed edges from target to source are placed --- this is needed
74+ undirected_edge_index = torch .cat ([_edge_index , reversed_edge_index ], dim = 1 )
75+
7176 # --- Dummy edge features ---
72- # Shape of edge_attr : num_of_edges x num_of_edges_features
77+ # Shape of undirected_edge_attr : num_of_edges x num_of_edges_features (26 x 1)
7378 # fmt: off
74- edge_attr = torch .tensor ([
75- [1 ], [ 1 ], # C0 - C1, This two features belong to elements at index 0 and 1 in `edge_index`
76- [2 ], [ 2 ], # C1 - C2, This two features belong to elements at index 2 and 3 in `edge_index`
77- [2 ], [ 2 ], # C1 - O3, This two features belong to elements at index 4 and 5 in `edge_index`
78- [2 ], [ 2 ], # O3 - C4, This two features belong to elements at index 6 and 7 in `edge_index`
79- [1 ], [ 1 ], # C4 - C5, This two features belong to elements at index 8 and 9 in `edge_index`
80- [1 ], [ 1 ], # C5 - C6, This two features belong to elements at index 10 and 11 in `edge_index`
81- [1 ], [ 1 ], # C6 - C7, This two features belong to elements at index 12 and 13 in `edge_index`
82- [1 ], [ 1 ], # C7 - C8, This two features belong to elements at index 14 and 15 in `edge_index`
83- [1 ], [ 1 ], # C8 - C9, This two features belong to elements at index 16 and 17 in `edge_index`
84- [1 ], [ 1 ], # C9 - C10, This two features belong to elements at index 18 and 19 in `edge_index`
85- [1 ], [ 1 ], # C10 - O11, This two features belong to elements at index 20 and 21 in `edge_index`
86- [1 ], [ 1 ], # C10 - O12, This two features belong to elements at index 22 and 23 in `edge_index`
87- [1 ], [ 1 ], # C9 - C4, This two features belong to elements at index 24 and 25 in `edge_index`
79+ _edge_attr = torch .tensor ([
80+ [1 ], # C0 - C1, This two features belong to elements at index 0 in `edge_index`
81+ [2 ], # C1 - C2, This two features belong to elements at index 1 in `edge_index`
82+ [2 ], # C1 - O3, This two features belong to elements at index 2 in `edge_index`
83+ [2 ], # O3 - C4, This two features belong to elements at index 3 in `edge_index`
84+ [1 ], # C4 - C5, This two features belong to elements at index 4 in `edge_index`
85+ [1 ], # C5 - C6, This two features belong to elements at index 5 in `edge_index`
86+ [1 ], # C6 - C7, This two features belong to elements at index 6 in `edge_index`
87+ [1 ], # C7 - C8, This two features belong to elements at index 7 in `edge_index`
88+ [1 ], # C8 - C9, This two features belong to elements at index 8 in `edge_index`
89+ [1 ], # C9 - C10, This two features belong to elements at index 9 in `edge_index`
90+ [1 ], # C10 - O11, This two features belong to elements at index 10 in `edge_index`
91+ [1 ], # C10 - O12, This two features belong to elements at index 11 in `edge_index`
92+ [1 ], # C9 - C4, This two features belong to elements at index 12 in `edge_index`
8893 ], dtype = torch .float )
8994 # fmt: on
9095
96+ # Alignement of edge attributes should in same order as of edge_index
97+ undirected_edge_attr = torch .cat ([_edge_attr , _edge_attr ], dim = 0 )
98+
9199 # Create graph data object
92- return Data (x = x , edge_index = edge_index , edge_attr = edge_attr )
100+ return Data (
101+ x = x , edge_index = undirected_edge_index , edge_attr = undirected_edge_attr
102+ )
0 commit comments