
Commit 37c76c1

Enhance README with augmented graphs explanation
Expanded the section on augmented graphs to explain the use of artificial nodes representing functional groups. Added details on connection schemes and provided commands for model and data configuration.
1 parent 524f665 commit 37c76c1

1 file changed: README.md (+26 additions, -8 deletions)

@@ -78,8 +78,20 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.lo

## Augmented Graphs

Graph Neural Networks (GNNs) often fail to explicitly leverage the chemically meaningful substructures present within molecules (i.e., **functional groups (FGs)**). To make this implicit information explicitly accessible to GNNs, we augment molecular graphs with **artificial nodes** that represent these substructures. The resulting graphs are referred to as **augmented graphs**.

> Note: Rings are also treated as functional groups in our work.

In these augmented graphs, each functional group node is connected to the atoms that constitute the group. Additionally, two functional group nodes are connected if any atom belonging to one group shares a bond with an atom from the other group. We further introduce a **graph node**, an extra node connected to all functional group nodes.
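The connection scheme above can be sketched in plain Python. This is an illustrative sketch only — the function and variable names are not from the chebai codebase:

```python
def augment_graph(num_atoms, bonds, groups):
    """Augment a molecular graph with functional-group (FG) nodes and a graph node.

    num_atoms: number of atom nodes (ids 0..num_atoms-1)
    bonds:     list of (atom, atom) edges from the molecular graph
    groups:    list of atom-id sets, one per functional group (rings included)
    """
    edges = list(bonds)
    fg_ids = list(range(num_atoms, num_atoms + len(groups)))

    # 1. Connect each FG node to the atoms that constitute the group.
    for fg, members in zip(fg_ids, groups):
        for atom in members:
            edges.append((fg, atom))

    # 2. Connect two FG nodes if any atom of one group shares a bond
    #    with an atom of the other group.
    for i in range(len(groups)):
        for j in range(i + 1, len(groups)):
            if any((a, b) in bonds or (b, a) in bonds
                   for a in groups[i] for b in groups[j]):
                edges.append((fg_ids[i], fg_ids[j]))

    # 3. Add a graph node connected to all FG nodes.
    graph_node = num_atoms + len(groups)
    edges += [(graph_node, fg) for fg in fg_ids]
    return edges, fg_ids, graph_node

# Toy example: a 4-atom chain split into two functional groups.
edges, fg_ids, graph_node = augment_graph(
    num_atoms=4, bonds=[(0, 1), (1, 2), (2, 3)], groups=[{0, 1}, {2, 3}]
)
```

In the toy example the two FG nodes (ids 4 and 5) become connected because atoms 1 and 2 share a bond, and the graph node (id 6) links to both FG nodes.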

Among all the connection schemes we evaluated, this configuration delivered the strongest performance. We denote it **WFG_WFGE_WGN** in our work; it is shown in the figure below.

<img width="1220" height="668" alt="mol_to_aug_mol" src="https://github.com/user-attachments/assets/0aba6b80-765b-45a6-913a-7d628f14a5db" />

<br/>
<br/>

Below is the command for the model and data configuration that achieved the best classification performance using augmented graphs.

```bash
python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.config.v2=True --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0
```
@@ -91,17 +103,23 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.lo

To use a GAT-based model, choose **one** of the following configs:

- **Standard Pooling**: `--model=../python-chebai-graph/configs/model/gat.yml`
> Standard pooling sums the learned representations of all nodes to produce a single representation, which is used for classification.

- **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml`
> With this pooling strategy, the learned representations are first separated into **two distinct sets**: those from atom nodes and those from all artificial nodes (both functional-group nodes and the graph node). The representations within each set are aggregated separately (using summation) to yield two distinct vectors. These two vectors are then concatenated before being passed to the classification layer.

- **Atom–Motif–Graph Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml`
> This approach employs a finer granularity of separation, distinguishing the learned representations into **three distinct sets**: atom nodes, functional-group (FG) nodes, and the single graph node. Summation is performed separately on the atom-node set and the FG-node set, yielding two vectors. These two vectors are then concatenated with the vector of the graph node before the final linear layer.
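The difference between the three pooling schemes can be sketched in a few lines of plain Python. The names below are illustrative (the actual behavior lives inside the model configs above), and summation stands in for whatever aggregation the models use:

```python
def vsum(vectors):
    # Element-wise sum of a list of equal-length vectors.
    return [sum(xs) for xs in zip(*vectors)]

def pool(node_reprs, node_types, scheme):
    """node_types marks each node as 'atom', 'fg', or 'graph'."""
    atoms = [v for v, t in zip(node_reprs, node_types) if t == "atom"]
    fgs = [v for v, t in zip(node_reprs, node_types) if t == "fg"]
    graph = [v for v, t in zip(node_reprs, node_types) if t == "graph"]
    if scheme == "standard":
        # One sum over every node.
        return vsum(node_reprs)
    if scheme == "atom_augmented":
        # Atoms vs. all artificial nodes, summed separately, then concatenated.
        return vsum(atoms) + vsum(fgs + graph)
    if scheme == "atom_motif_graph":
        # Atoms, FG nodes, and the graph node kept apart before concatenation.
        return vsum(atoms) + vsum(fgs) + graph[0]
    raise ValueError(f"unknown scheme: {scheme}")

# Two atoms, two FG nodes, one graph node, 2-dimensional representations.
reprs = [[1, 0], [0, 1], [2, 2], [3, 3], [5, 5]]
types = ["atom", "atom", "fg", "fg", "graph"]
```

Note how the output dimensionality grows with the separation granularity: standard pooling keeps the hidden size, while the two- and three-set schemes concatenate two and three vectors respectively before the classification layer.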

#### GAT-specific hyperparameters

- **Number of message-passing layers**: `--model.config.num_layers=5` (default: 4)
- **Attention heads**: `--model.config.heads=4` (default: 8)
> **Note**: The output channels (or hidden channels, if output channels are not specified) should be divisible by the number of heads.
- **Use GATv2**: `--model.config.v2=True` (default: False)
> **Note**: GATv2 addresses the limitation of static attention in GAT by introducing a dynamic attention mechanism. For further details, please refer to the [original GATv2 paper](https://arxiv.org/abs/2105.14491).

#### **ResGated Architecture**

To use a ResGated GNN model, choose **one** of the following configs:
@@ -117,9 +135,9 @@ These can be used for both GAT and ResGated architectures:

- **Dropout**: `--model.config.dropout=0.1` (default: 0)
- **Number of final linear layers**: `--model.n_linear_layers=2` (default: 1)

## Random Node Initialization

### Static Node Initialization

In this type of node initialization, the node features (and/or edge features) of the given molecular graph are initialized only once during dataset creation with the given initialization scheme.
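A minimal sketch of what "initialized only once" means in practice — names here are illustrative, not taken from the data class code:

```python
import random

def make_static_dataset(mol_sizes, dim, seed=0):
    """Static initialization: node features are drawn once, when the dataset
    is created, and then stay fixed for every training epoch."""
    rng = random.Random(seed)
    dataset = []
    for num_nodes in mol_sizes:
        # Drawn once from the chosen distribution (Gaussian here) ...
        x = [[rng.gauss(0.0, 1.0) for _ in range(dim)]
             for _ in range(num_nodes)]
        # ... and stored with the graph, never re-drawn afterwards.
        dataset.append({"x": x})
    return dataset

dataset = make_static_dataset(mol_sizes=[3, 5], dim=4)
epoch_1_features = dataset[0]["x"]
epoch_2_features = dataset[0]["x"]  # identical: nothing is re-sampled
```

Because the features are stored with the dataset, every epoch (and every forward pass) sees exactly the same values for a given graph.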

@@ -143,7 +161,7 @@ If you want all node (and edge) features to be drawn from a given distribution (
Refer to the data class code for details.

### Dynamic Node Initialization

In this type of node initialization, the node features (and/or edge features) of the molecular graph are initialized at **each forward pass** of the model using the given initialization scheme.
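In contrast to the static case, a dynamic initializer re-samples features on every call. A minimal sketch with illustrative names (the real implementation lives in the model/data classes):

```python
import random

class DynamicRandomInit:
    """Dynamic initialization: node features are re-drawn from the given
    distribution at every forward pass of the model."""

    def __init__(self, dim):
        self.dim = dim

    def __call__(self, num_nodes):
        # Invoked inside the model's forward(): a fresh sample every time.
        return [[random.gauss(0.0, 1.0) for _ in range(self.dim)]
                for _ in range(num_nodes)]

init = DynamicRandomInit(dim=8)
features_pass_1 = init(num_nodes=5)  # drawn at forward pass 1
features_pass_2 = init(num_nodes=5)  # re-drawn at forward pass 2
```

The same graph therefore sees different random features on each forward pass, which is the defining difference from static initialization.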
