Skip to content

Commit 1a8dcb6

Browse files
committed
first src to tgt edges then tgt to src
- instead of using adjacent directed edge, this one is better approach since we can stack edge attributes generated later without any further logic to rearrange edge_attr
1 parent 0a9760d commit 1a8dcb6

File tree

2 files changed

+37
-30
lines changed

2 files changed

+37
-30
lines changed

chebai_graph/preprocessing/reader.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
import torch
99
from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn
1010
from torch_geometric.data import Data as GeomData
11-
from torch_geometric.utils import from_networkx, to_undirected
1211

13-
import chebai_graph.preprocessing.properties as properties
12+
from chebai_graph.preprocessing import properties
1413
from chebai_graph.preprocessing.collate import GraphCollator
1514

1615

@@ -55,12 +54,10 @@ def _read_data(self, raw_data):
5554

5655
x = torch.zeros((mol.GetNumAtoms(), 0))
5756

58-
# We need to ensure that directed edges which form a undirected edge are adjacent to each other
59-
edge_index_list = [[], []]
60-
for bond in mol.GetBonds():
61-
edge_index_list[0].extend([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
62-
edge_index_list[1].extend([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()])
63-
edge_index = torch.tensor(edge_index_list, dtype=torch.long)
57+
# First source to target edges, then target to source edges
58+
src = [bond.GetBeginAtomIdx() for bond in mol.GetBonds()]
59+
tgt = [bond.GetEndAtomIdx() for bond in mol.GetBonds()]
60+
edge_index = torch.tensor([src + tgt, tgt + src], dtype=torch.long)
6461

6562
# edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features]
6663
edge_attr = torch.zeros((edge_index.size(1), 0))

tests/unit/test_data.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -58,35 +58,45 @@ def get_aspirin_graph(self):
5858

5959
# --- Edge list (bidirectional) ---
6060
# Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26)
61-
# 2 directed edges of one undirected edge are adjacent to each other --- this is needed
62-
63-
# fmt: off
6461
# Generated using RDKIT 2024.9.6
65-
edge_index = torch.tensor([
66-
[0, 1, 1, 2, 1, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 10, 12, 9, 4], # Start atoms (u)
67-
[1, 0, 2, 1, 3, 1, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8, 10, 9, 11, 10, 12, 10, 4, 9] # End atoms (v)
62+
# fmt: off
63+
_edge_index = torch.tensor([
64+
[0, 1, 1, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9], # Start atoms (u)
65+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 4] # End atoms (v)
6866
], dtype=torch.long)
6967
# fmt: on
7068

69+
# Reverse the edges
70+
reversed_edge_index = _edge_index[[1, 0], :]
71+
72+
# First all directed edges from source to target are placed,
73+
# then all directed edges from target to source are placed --- this is needed
74+
undirected_edge_index = torch.cat([_edge_index, reversed_edge_index], dim=1)
75+
7176
# --- Dummy edge features ---
72-
# Shape of edge_attr: num_of_edges x num_of_edges_features
77+
# Shape of undirected_edge_attr: num_of_edges x num_of_edges_features (26 x 1)
7378
# fmt: off
74-
edge_attr = torch.tensor([
75-
[1], [1], # C0 - C1, This two features belong to elements at index 0 and 1 in `edge_index`
76-
[2], [2], # C1 - C2, This two features belong to elements at index 2 and 3 in `edge_index`
77-
[2], [2], # C1 - O3, This two features belong to elements at index 4 and 5 in `edge_index`
78-
[2], [2], # O3 - C4, This two features belong to elements at index 6 and 7 in `edge_index`
79-
[1], [1], # C4 - C5, This two features belong to elements at index 8 and 9 in `edge_index`
80-
[1], [1], # C5 - C6, This two features belong to elements at index 10 and 11 in `edge_index`
81-
[1], [1], # C6 - C7, This two features belong to elements at index 12 and 13 in `edge_index`
82-
[1], [1], # C7 - C8, This two features belong to elements at index 14 and 15 in `edge_index`
83-
[1], [1], # C8 - C9, This two features belong to elements at index 16 and 17 in `edge_index`
84-
[1], [1], # C9 - C10, This two features belong to elements at index 18 and 19 in `edge_index`
85-
[1], [1], # C10 - O11, This two features belong to elements at index 20 and 21 in `edge_index`
86-
[1], [1], # C10 - O12, This two features belong to elements at index 22 and 23 in `edge_index`
87-
[1], [1], # C9 - C4, This two features belong to elements at index 24 and 25 in `edge_index`
79+
_edge_attr = torch.tensor([
80+
[1], # C0 - C1, This two features belong to elements at index 0 in `edge_index`
81+
[2], # C1 - C2, This two features belong to elements at index 1 in `edge_index`
82+
[2], # C1 - O3, This two features belong to elements at index 2 in `edge_index`
83+
[2], # O3 - C4, This two features belong to elements at index 3 in `edge_index`
84+
[1], # C4 - C5, This two features belong to elements at index 4 in `edge_index`
85+
[1], # C5 - C6, This two features belong to elements at index 5 in `edge_index`
86+
[1], # C6 - C7, This two features belong to elements at index 6 in `edge_index`
87+
[1], # C7 - C8, This two features belong to elements at index 7 in `edge_index`
88+
[1], # C8 - C9, This two features belong to elements at index 8 in `edge_index`
89+
[1], # C9 - C10, This two features belong to elements at index 9 in `edge_index`
90+
[1], # C10 - O11, This two features belong to elements at index 10 in `edge_index`
91+
[1], # C10 - O12, This two features belong to elements at index 11 in `edge_index`
92+
[1], # C9 - C4, This two features belong to elements at index 12 in `edge_index`
8893
], dtype=torch.float)
8994
# fmt: on
9095

96+
# Alignement of edge attributes should in same order as of edge_index
97+
undirected_edge_attr = torch.cat([_edge_attr, _edge_attr], dim=0)
98+
9199
# Create graph data object
92-
return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
100+
return Data(
101+
x=x, edge_index=undirected_edge_index, edge_attr=undirected_edge_attr
102+
)

0 commit comments

Comments
 (0)