
Commit 3edf151

Merge pull request #30 from ChEB-AI/documentation
Documentation for Augmented graphs and GNI
2 parents 0ef9e87 + 83ebc62 commit 3edf151

File tree

3 files changed (+301 additions, −30 deletions)


README.md

Lines changed: 111 additions & 0 deletions
@@ -75,3 +75,114 @@ The list can be found in the `configs/data/chebi50_graph_properties.yml` file.
```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml
```
## Augmented Graphs

Graph Neural Networks (GNNs) often fail to explicitly leverage the chemically meaningful substructures present within molecules (i.e. **functional groups (FGs)**). To make this implicit information explicitly accessible to GNNs, we augment molecular graphs with **artificial nodes** that represent these substructures. The resulting graphs are referred to as **augmented graphs**.
> Note: Rings are also treated as functional groups in our work.

In these augmented graphs, each functional group node is connected to the atoms that constitute the group. Additionally, two functional group nodes are connected if any atom belonging to one group shares a bond with an atom from the other group. We further introduce a **graph node**, an extra node connected to all functional group nodes.
Among all the connection schemes we evaluated, this configuration delivered the strongest performance. We denote it with the abbreviation **WFG_WFGE_WGN** in our work; it is shown in the figure below.
<img width="1220" height="668" alt="mol_to_aug_mol" src="https://github.com/user-attachments/assets/0aba6b80-765b-45a6-913a-7d628f14a5db" />

<br/>
<br/>
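For concreteness, here is a minimal, self-contained sketch of the WFG_WFGE_WGN construction on a toy molecule. The SMARTS patterns and helper logic are illustrative assumptions, not the library's actual functional-group detection or data pipeline:

```python
# Illustrative sketch only: toy FG patterns, not the repo's real FG catalogue.
from rdkit import Chem

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin
fg_smarts = ["C(=O)O", "c1ccccc1"]  # toy carboxyl and benzene-ring patterns

n_atoms = mol.GetNumAtoms()
edges = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds()]
bonds = {frozenset(e) for e in edges}

# Collect the atom sets of all matched functional groups (rings included).
fg_members = [set(match)
              for smarts in fg_smarts
              for match in mol.GetSubstructMatches(Chem.MolFromSmarts(smarts))]

# WFG: one artificial node per FG, connected to its member atoms.
fg_ids = [n_atoms + i for i in range(len(fg_members))]
for fg_id, members in zip(fg_ids, fg_members):
    edges += [(fg_id, a) for a in members]

# WFGE: connect two FG nodes if any of their atoms share a bond.
for i in range(len(fg_members)):
    for j in range(i + 1, len(fg_members)):
        if any(frozenset((a, b)) in bonds
               for a in fg_members[i] for b in fg_members[j]):
            edges.append((fg_ids[i], fg_ids[j]))

# WGN: a single graph node connected to all FG nodes.
graph_node = n_atoms + len(fg_members)
edges += [(graph_node, fg_id) for fg_id in fg_ids]
```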
Below is the command for the model and data configuration that achieved the best classification performance using augmented graphs.
```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.config.v2=True --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0
```
### Model Hyperparameters

#### **GAT Architecture**

To use a GAT-based model, choose **one** of the following configs:
- **Standard Pooling**: `--model=../python-chebai-graph/configs/model/gat.yml`
  > Standard pooling sums the learned representations of all nodes to produce a single representation, which is used for classification.

- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml`
  > With this pooling strategy, the learned representations are first separated into **two distinct sets**: those from atom nodes and those from all artificial nodes (both functional group nodes and the graph node). The representations within each set are aggregated separately (using summation) to yield two distinct vectors. These two resulting vectors are then concatenated before being passed to the classification layer.

- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml`
  > This approach employs a finer granularity of separation, distinguishing learned representations into **three distinct sets**: atom nodes, functional group (FG) nodes, and the single graph node. Summation is performed separately on the atom node set and the FG node set, yielding two vectors. These two vectors are then concatenated with the single vector corresponding to the graph node before the final linear layer (see the sketch after this list).
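The pooling variants differ only in how node representations are grouped before aggregation. Below is a minimal, unbatched sketch of Atom–Motif–Graph pooling, assuming each node carries a type label (0 = atom, 1 = FG node, 2 = graph node); the repository's batched implementation will differ in detail:

```python
import torch

def amg_pool(h: torch.Tensor, node_type: torch.Tensor) -> torch.Tensor:
    """Sum atom and FG embeddings separately, then concatenate both
    with the single graph-node embedding (illustrative, single graph)."""
    atom_vec = h[node_type == 0].sum(dim=0)
    fg_vec = h[node_type == 1].sum(dim=0)
    graph_vec = h[node_type == 2].squeeze(0)
    return torch.cat([atom_vec, fg_vec, graph_vec], dim=-1)

h = torch.randn(10, 64)                            # learned node representations
node_type = torch.tensor([0] * 7 + [1] * 2 + [2])  # 7 atoms, 2 FGs, 1 graph node
pooled = amg_pool(h, node_type)                    # shape (192,), fed to the classifier
```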
#### **GAT-Specific Hyperparameters**

- **Number of message-passing layers**: `--model.config.num_layers=5` (default: 4)
- **Attention heads**: `--model.config.heads=4` (default: 8)
  > **Note**: The output channels (or hidden channels, if output channels are not specified) should be divisible by the number of heads.
- **Use GATv2**: `--model.config.v2=True` (default: False)
  > **Note**: GATv2 addresses the limitation of static attention in GAT by introducing a dynamic attention mechanism. For further details, please refer to the [original GATv2 paper](https://arxiv.org/abs/2105.14491).
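To make the divisibility note concrete, here is a hedged sketch using PyTorch Geometric's `GATv2Conv`; the repo builds on PyG, but the exact layer wiring shown here is an assumption for illustration:

```python
import torch
from torch_geometric.nn import GATv2Conv

in_channels, hidden_channels, heads = 158, 256, 8
# Hidden channels are split evenly across heads, hence the constraint.
assert hidden_channels % heads == 0

conv = GATv2Conv(in_channels, hidden_channels // heads, heads=heads)
x = torch.randn(12, in_channels)            # 12 nodes with RDKit-style features
edge_index = torch.randint(0, 12, (2, 30))  # random toy connectivity
out = conv(x, edge_index)                   # concatenated heads -> (12, 256)
```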
#### **ResGated Architecture**

To use a ResGated GNN model, choose **one** of the following configs:

- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/res_aug_amgpool.yml`
- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/res_aug_aagpool.yml`
- **Standard Pooling**: `--model=../python-chebai-graph/configs/model/resgated.yml`
#### **Common Hyperparameters**

These can be used for both GAT and ResGated architectures:

- **Dropout**: `--model.config.dropout=0.1` (default: 0)
- **Number of final linear layers**: `--model.n_linear_layers=2` (default: 1)
## Random Node Initialization

### Static Node Initialization

With static node initialization, the node features (and/or edge features) of the given molecular graph are initialized only once, during dataset creation, using the given initialization scheme.
```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0
```
In the above command, each node gets the 158 node features retrieved from RDKit (corresponding to the node properties defined in `chebi50_graph_properties.yml`) plus 45 additional features (specified by `--data.pad_node_features=45`) drawn from a normal distribution (the default). Together these give the 203 input channels set by `--model.config.in_channels=203`.
You can change the distribution from which the additional features are drawn by adding the following option to the above command: `--data.distribution=zeros`

Available distributions: `"normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"`
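As a rough illustration of what static padding produces, here is a hypothetical helper (not the library's actual reader API) that appends randomly drawn columns to existing features once:

```python
import torch

def pad_features(x: torch.Tensor, n_pad: int, distribution: str = "normal") -> torch.Tensor:
    """Append n_pad randomly initialized columns to the feature matrix x."""
    pad = torch.empty(x.shape[0], n_pad)
    init = {
        "normal": torch.nn.init.normal_,
        "uniform": torch.nn.init.uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "zeros": torch.nn.init.zeros_,
    }[distribution]
    init(pad)
    return torch.cat([x, pad], dim=1)

x = torch.randn(20, 158)  # 20 nodes with 158 RDKit-derived features
x = pad_features(x, 45)   # -> (20, 203), matching in_channels=203
```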
Similarly, each edge is initialized with 7 RDKit features and 4 additional features drawn from the given distribution, matching `--model.config.edge_dim=11`.

If you want all node (and edge) features to be drawn from a given distribution (i.e., ignore RDKit features), use: `--data=../python-chebai-graph/configs/data/chebi50_static_gni.yml`

Refer to the data class code for details.
### Dynamic Node Initialization

With dynamic node initialization, the node features (and/or edge features) of the molecular graph are initialized at **each forward pass** of the model using the given initialization scheme.
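Conceptually, the difference from static initialization can be sketched as follows; this stub is a loose, unbatched illustration, not the actual `ResGatedDynamicGNI` code:

```python
import torch

class DynamicRNIStub(torch.nn.Module):
    """Redraws (or pads) node features on every forward pass."""

    def __init__(self, complete_randomness: bool = True, pad_node_features: int = 0):
        super().__init__()
        self.complete_randomness = complete_randomness
        self.pad_node_features = pad_node_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.complete_randomness:
            # Replace all node features with fresh random values each call
            x = torch.randn_like(x)
        elif self.pad_node_features:
            # Keep RDKit features, append fresh random columns each call
            pad = torch.randn(x.shape[0], self.pad_node_features, device=x.device)
            x = torch.cat([x, pad], dim=1)
        return x  # would then be fed to the message-passing layers
```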
Currently, dynamic node initialization is implemented only for the **resgated** architecture, selected via: `--model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml`

To keep the RDKit features and *add* dynamically initialized features, include the following options in the command:

```
--model.config.complete_randomness=False
--model.config.pad_node_features=45
```
The additional features are drawn from a normal distribution (the default). You can change this using: `--model.config.distribution=uniform`

If all features should be initialized from the given distribution, remove the `complete_randomness` flag (its default is True).

Below is the command for a typical dynamic node initialization run:
```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.config.complete_randomness=False --model.config.pad_node_features=45 --model.config.pad_edge_features=4 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_dres_props+rand_s0
```

chebai_graph/models/dynamic_gni.py

Lines changed: 55 additions & 6 deletions
```diff
@@ -1,3 +1,23 @@
+"""
+ResGatedDynamicGNIGraphPred
+------------------------------------------------
+
+Module providing a ResGated GNN model that applies Random Node Initialization
+(RNI) dynamically at each forward pass. This follows the approach from:
+
+Abboud, R., et al. (2020). "The surprising power of graph neural networks with
+random node initialization." arXiv preprint arXiv:2010.01179.
+
+The module exposes:
+- ResGatedDynamicGNI: a model that can either completely replace node/edge
+  features with random tensors each forward pass or pad existing features with
+  additional random features.
+- ResGatedDynamicGNIGraphPred: a thin wrapper that instantiates the above for
+  graph-level prediction pipelines.
+"""
+
+__all__ = ["ResGatedDynamicGNIGraphPred"]
+
 from typing import Any

 import torch
```
```diff
@@ -14,12 +34,37 @@

 class ResGatedDynamicGNI(GraphModelBase):
     """
-    Base model class for applying ResGatedGraphConv layers to graph-structured data
-    with dynamic initialization of features for nodes and edges.
-
-    Args:
-        config (dict): Configuration dictionary containing model hyperparameters.
-        **kwargs: Additional keyword arguments for parent class.
+    ResGated GNN with dynamic Random Node Initialization (RNI).
+
+    This model supports two modes controlled by the `config`:
+
+    - complete_randomness (bool-like): If True, **replace** node and edge
+      features entirely with randomly initialized tensors each forward pass.
+      If False, the model **pads** existing features with extra randomly
+      initialized features on-the-fly.
+
+    - pad_node_features (int, optional): Number of random columns to append
+      to each node feature vector when `complete_randomness` is False.
+
+    - pad_edge_features (int, optional): Number of random columns to append
+      to each edge feature vector when `complete_randomness` is False.
+
+    - distribution (str): Distribution for random initialization. Must be one
+      of RandomFeatureInitializationReader.DISTRIBUTIONS.
+
+    Parameters
+    ----------
+    config : Dict[str, Any]
+        Configuration dictionary containing model hyperparameters. Expected keys
+        used by this class:
+        - distribution (optional, default "normal")
+        - complete_randomness (optional, default "True")
+        - pad_node_features (optional, int)
+        - pad_edge_features (optional, int)
+        Keys required by GraphModelBase (e.g., in_channels, hidden_channels,
+        out_channels, num_layers, edge_dim) should also be present.
+    **kwargs : Any
+        Additional keyword arguments forwarded to GraphModelBase.
     """

     def __init__(self, config: dict[str, Any], **kwargs: Any):
```
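For orientation, a hypothetical `config` dict matching the documented keys; the values are illustrative placeholders, partly echoing the README commands above:

```python
config = {
    # keys required by GraphModelBase (illustrative values)
    "in_channels": 203,
    "hidden_channels": 256,
    "out_channels": 854,  # e.g., number of target classes; placeholder value
    "num_layers": 4,
    "edge_dim": 11,
    # keys used by ResGatedDynamicGNI
    "distribution": "normal",
    "complete_randomness": False,
    "pad_node_features": 45,
    "pad_edge_features": 4,
}
```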
```diff
@@ -96,6 +141,8 @@ def forward(self, batch: dict[str, Any]) -> Tensor:

         new_x = None
         new_edge_attr = None
+
+        # If replacing features entirely with random values
         if self.complete_randomness:
             new_x = torch.empty(
                 graph_data.x.shape[0], graph_data.x.shape[1], device=self.device
```
```diff
@@ -110,6 +157,8 @@ def forward(self, batch: dict[str, Any]) -> Tensor:
                 RandomFeatureInitializationReader.random_gni(
                     new_edge_attr, self.distribution
                 )
+
+        # If padding existing features with additional random columns
         else:
             if self.pad_node_features is not None:
                 pad_node = torch.empty(
```
