Skip to content

Commit e30e3ac

Browse files
committed
static gni
1 parent e71dfa2 commit e30e3ac

File tree

5 files changed

+110
-3
lines changed

5 files changed

+110
-3
lines changed

chebai_graph/preprocessing/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
ChEBI50_Atom_WGNOnly_GraphProp,
33
ChEBI50_NFGE_NGN_GraphProp,
44
ChEBI50_NFGE_WGN_GraphProp,
5+
ChEBI50_StaticGNI,
56
ChEBI50_WFGE_NGN_GraphProp,
67
ChEBI50_WFGE_WGN_GraphProp,
78
ChEBI50GraphData,
@@ -19,4 +20,5 @@
1920
"ChEBI50_NFGE_WGN_GraphProp",
2021
"ChEBI50_WFGE_NGN_GraphProp",
2122
"ChEBI50_WFGE_WGN_GraphProp",
23+
"ChEBI50_StaticGNI",
2224
]

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
AtomsFGReader_NoFGEdges_NoGraphNode,
3232
GraphPropertyReader,
3333
GraphReader,
34+
RandomNodeInitializationReader,
3435
)
3536

3637
from .utils import resolve_property
@@ -188,9 +189,11 @@ def __init__(
188189
f"[Info] Atom-level features will be zero-padded with "
189190
f"{self.zero_pad_atom} additional dimensions."
190191
)
191-
print(
192-
f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}"
193-
)
192+
193+
if self.properties:
194+
print(
195+
f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}"
196+
)
194197

195198
def _merge_props_into_base(self, row: pd.Series) -> GeomData:
196199
"""
@@ -504,6 +507,24 @@ def _merge_props_into_base(
504507
)
505508

506509

510+
class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50):
511+
READER = RandomNodeInitializationReader
512+
513+
def _setup_properties(self): ...
514+
515+
def load_processed_data_from_file(self, filename):
516+
base_data = super().load_processed_data_from_file(filename)
517+
base_df = pd.DataFrame(base_data)
518+
519+
rank_zero_info(
520+
f"Use following values for given parameters for model configuration: \n\t"
521+
f"in_channels: {self.reader.num_node_properties} , "
522+
f"edge_dim: {self.reader.num_bond_properties}, "
523+
f"n_molecule_properties: {self.reader.num_molecule_properties}"
524+
)
525+
return base_df[base_data[0].keys()].to_dict("records")
526+
527+
507528
class ChEBI50GraphProperties(GraphPropertiesMixIn, ChEBIOver50):
508529
"""ChEBIOver50 dataset with molecular property encodings."""
509530

chebai_graph/preprocessing/reader/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
AtomsFGReader_NoFGEdges_NoGraphNode,
77
)
88
from .reader import GraphPropertyReader, GraphReader
9+
from .static_gni import RandomNodeInitializationReader
910

1011
__all__ = [
1112
"GraphReader",
@@ -15,4 +16,5 @@
1516
"AtomFGReader_NoFGEdges_WithGraphNode",
1617
"AtomFGReader_WithFGEdges_NoGraphNode",
1718
"AtomFGReader_WithFGEdges_WithGraphNode",
19+
"RandomNodeInitializationReader",
1820
]
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
Abboud, Ralph, et al.
3+
"The surprising power of graph neural networks with random node initialization."
4+
arXiv preprint arXiv:2010.01179 (2020).
5+
6+
Code Reference: https://github.com/ralphabb/GNN-RNI/blob/main/GNNHyb.py
7+
"""
8+
9+
import torch
10+
from torch_geometric.data import Data as GeomData
11+
12+
from .reader import GraphPropertyReader
13+
14+
15+
class RandomNodeInitializationReader(GraphPropertyReader):
16+
def __init__(
17+
self,
18+
num_node_properties: int,
19+
num_bond_properties: int,
20+
num_molecule_properties: int,
21+
distribution: str,
22+
*args,
23+
**kwargs,
24+
):
25+
super().__init__(*args, **kwargs)
26+
self.num_node_properties = num_node_properties
27+
self.num_bond_properties = num_bond_properties
28+
self.num_molecule_properties = num_molecule_properties
29+
assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"]
30+
self.distribution = distribution
31+
32+
def name(self) -> str:
33+
"""
34+
Get the name identifier of the reader.
35+
36+
Returns:
37+
str: The name of the reader.
38+
"""
39+
return f"gni-{self.distribution}-node{self.num_node_properties}-bond{self.num_bond_properties}-mol{self.num_molecule_properties}"
40+
41+
def _read_data(self, raw_data):
42+
data: GeomData = super()._read_data(raw_data)
43+
random_x = torch.empty(data.x.shape[0], self.num_node_properties)
44+
random_edge_attr = torch.empty(
45+
data.edge_index.shape[1], self.num_bond_properties
46+
)
47+
random_molecule_properties = torch.empty(1, self.num_molecule_properties)
48+
49+
if self.distribution == "normal":
50+
torch.nn.init.normal_(random_x)
51+
torch.nn.init.normal_(random_edge_attr)
52+
torch.nn.init.normal_(random_molecule_properties)
53+
elif self.distribution == "uniform":
54+
torch.nn.init.uniform_(random_x, a=-1.0, b=1.0)
55+
torch.nn.init.uniform_(random_edge_attr, a=-1.0, b=1.0)
56+
torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0)
57+
elif self.distribution == "xavier_normal":
58+
torch.nn.init.xavier_normal_(random_x)
59+
torch.nn.init.xavier_normal_(random_edge_attr)
60+
torch.nn.init.xavier_normal_(random_molecule_properties)
61+
elif self.distribution == "xavier_uniform":
62+
torch.nn.init.xavier_uniform_(random_x)
63+
torch.nn.init.xavier_uniform_(random_edge_attr)
64+
torch.nn.init.xavier_uniform_(random_molecule_properties)
65+
else:
66+
raise ValueError("Unknown distribution type")
67+
68+
data.x = random_x
69+
data.edge_attr = random_edge_attr
70+
data.molecule_attr = random_molecule_properties
71+
return data
72+
73+
def read_property(self, *args, **kwargs) -> Exception:
74+
"""This reader does not support reading specific properties."""
75+
raise NotImplementedError("This reader only performs random initialization.")
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class_path: chebai_graph.preprocessing.datasets.ChEBI50_StaticGNI
2+
init_args:
3+
reader_kwargs:
4+
num_node_properties: 158
5+
num_bond_properties: 7
6+
num_molecule_properties: 200
7+
distribution: normal

0 commit comments

Comments
 (0)