Commit 521517c

HFooladi and claude committed
feat: implement Phase 3.2 Graph Attention Network (GAT)
Add GAT model with multi-head attention for adaptive neighbor weighting. Key features include optional edge feature incorporation in attention, attention dropout for regularization, and API compatibility with UncertaintyGCN/UncertaintyMPNN for drop-in replacement.

- GATConfig with n_heads, edge_features, attention_dropout_rate params
- GATAttention for single-head concat-based attention mechanism
- GATLayer for multi-head attention with head concatenation/averaging
- UncertaintyGAT with dual heads (mean/variance) and log_var clipping
- Training utilities: create_gat_optimizer, train_gat_step, eval_gat_step
- MC Dropout uncertainty via get_gat_uncertainties()
- 36 comprehensive tests covering all components
- Demo example comparing GAT vs GCN vs MPNN on ESOL dataset

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 46d37c0 commit 521517c
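The "adaptive neighbor weighting" the commit message describes is the core GAT idea: attention logits are computed per edge and softmax-normalized over each node's incoming edges before aggregation. The following is a toy numpy sketch of that mechanism, not the repository's actual implementation (which uses `jraph.segment_softmax`); the additive scoring of logits here is a stand-in for the learned `LeakyReLU(a^T [Wh_i || Wh_j])` scoring in a real GAT.

```python
import numpy as np

def segment_softmax(logits, segment_ids, num_segments):
    """Numerically stable softmax over edges grouped by receiver node."""
    maxes = np.full(num_segments, -np.inf)
    np.maximum.at(maxes, segment_ids, logits)
    exp = np.exp(logits - maxes[segment_ids])
    sums = np.zeros(num_segments)
    np.add.at(sums, segment_ids, exp)
    return exp / sums[segment_ids]

# Toy graph: node 2 receives edges from nodes 0 and 1
senders = np.array([0, 1])
receivers = np.array([2, 2])
nodes = np.array([[1.0], [3.0], [0.0]])  # one scalar feature per node

# Stand-in attention logits (a real GAT learns these from node pairs)
logits = (nodes[senders] + nodes[receivers]).squeeze(-1)   # [1.0, 3.0]
alpha = segment_softmax(logits, receivers, num_segments=3)  # sums to 1 per receiver

# Attention-weighted aggregation of sender features into receivers
out = np.zeros_like(nodes)
np.add.at(out, receivers, alpha[:, None] * nodes[senders])
# out[2, 0] ≈ 2.7616: node 1's larger feature dominates the weighted sum
```

Node 2's aggregate leans toward the higher-scoring neighbor rather than a uniform mean, which is exactly what distinguishes GAT from the fixed averaging of a GCN.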

File tree

7 files changed: +1472 −25 lines

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
```diff
@@ -8,6 +8,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 
 ### Added
+- **Graph Attention Network (GAT)** (`molax/models/gat.py`)
+  - `UncertaintyGAT` model with multi-head attention for adaptive neighbor weighting
+  - `GATConfig` with configurable n_heads, edge_features, attention_dropout_rate
+  - `GATAttention` and `GATLayer` components
+  - Training utilities: `train_gat_step`, `eval_gat_step`, `get_gat_uncertainties`
+  - Same API as `UncertaintyGCN`/`UncertaintyMPNN` for drop-in replacement
+  - Optional edge feature incorporation in attention computation
+  - Comprehensive tests and demo example
+
 - **Message Passing Neural Network (MPNN)** (`molax/models/mpnn.py`)
   - `UncertaintyMPNN` model that leverages edge features (bond information)
   - `MPNNConfig` with configurable aggregation (sum, mean, max)
```

CLAUDE.md

Lines changed: 25 additions & 1 deletion
````diff
@@ -27,6 +27,7 @@ mkdocs serve
 python examples/simple_active_learning.py
 python examples/active_learning_benchmark.py
 python examples/mpnn_demo.py
+python examples/gat_demo.py
 python examples/ensemble_demo.py
 python examples/evidential_demo.py
 ```
@@ -66,7 +67,7 @@ SMILES string
 jraph.GraphsTuple (single molecule)
   ↓ jraph.batch()
 jraph.GraphsTuple (batched - all molecules as one graph)
-  ↓ UncertaintyGCN / UncertaintyMPNN / DeepEnsemble / EvidentialGCN
+  ↓ UncertaintyGCN / UncertaintyMPNN / UncertaintyGAT / DeepEnsemble / EvidentialGCN
 (mean, variance) predictions
 ```
 
@@ -76,6 +77,7 @@ jraph.GraphsTuple (batched - all molecules as one graph)
 |------|---------|
 | `molax/models/gcn.py` | `GCNConfig`, `UncertaintyGCN`, `MolecularGCN`, `train_step`, `eval_step` |
 | `molax/models/mpnn.py` | `MPNNConfig`, `UncertaintyMPNN` for edge-aware message passing |
+| `molax/models/gat.py` | `GATConfig`, `UncertaintyGAT` for attention-based message passing |
 | `molax/models/ensemble.py` | `EnsembleConfig`, `DeepEnsemble` for ensemble uncertainty |
 | `molax/models/evidential.py` | `EvidentialConfig`, `EvidentialGCN` for evidential uncertainty |
 | `molax/utils/data.py` | `MolecularDataset`, `smiles_to_jraph`, `batch_graphs` |
@@ -153,6 +155,27 @@ model = UncertaintyMPNN(config, rngs=nnx.Rngs(0))
 mean, variance = model(batched_graphs, training=False)
 ```
 
+### GAT API
+
+```python
+from molax.models.gat import GATConfig, UncertaintyGAT
+
+config = GATConfig(
+    node_features=6,
+    edge_features=1,  # Optional: include edge features in attention
+    hidden_features=[64, 64],
+    out_features=1,
+    n_heads=4,  # Multi-head attention
+    dropout_rate=0.1,
+    attention_dropout_rate=0.1,
+    negative_slope=0.2,  # LeakyReLU slope
+)
+model = UncertaintyGAT(config, rngs=nnx.Rngs(0))
+
+# Same API as UncertaintyGCN/UncertaintyMPNN - uses attention for aggregation
+mean, variance = model(batched_graphs, training=False)
+```
+
 ### Calibration Metrics
 
 ```python
@@ -186,6 +209,7 @@ optimizer.update(model, grads)
 pytest tests/ -v                    # All tests
 pytest tests/test_gcn.py -v         # GCN model tests
 pytest tests/test_mpnn.py -v        # MPNN model tests
+pytest tests/test_gat.py -v         # GAT model tests
 pytest tests/test_ensemble.py -v    # Ensemble tests
 pytest tests/test_evidential.py -v  # Evidential tests
 pytest tests/test_acquisition.py -v # Acquisition tests
````

docs/roadmap.md

Lines changed: 23 additions & 24 deletions
````diff
@@ -423,7 +423,9 @@ embeddings = model.extract_embeddings(batched_graphs)
 
 ---
 
-### 3.2 Graph Attention Network (GAT)
+### 3.2 Graph Attention Network (GAT) ✅
+
+**Status:** Implemented in `molax/models/gat.py`
 
 **What:** Learn edge importance dynamically via attention mechanism.
 
@@ -433,35 +435,32 @@ embeddings = model.extract_embeddings(batched_graphs)
 
 ```python
 # molax/models/gat.py
+from molax.models.gat import GATConfig, UncertaintyGAT
 
-class GATLayer(nnx.Module):
-    def __init__(self, in_features, out_features, n_heads=4, rngs=None):
-        self.n_heads = n_heads
-        self.head_dim = out_features // n_heads
-
-        self.W = nnx.Linear(in_features, out_features, rngs=rngs)
-        self.attention = nnx.Linear(2 * self.head_dim, 1, rngs=rngs)
-
-    def __call__(self, graphs):
-        nodes = self.W(graphs.nodes)  # (N, out_features)
-        nodes = nodes.reshape(-1, self.n_heads, self.head_dim)
+config = GATConfig(
+    node_features=6,
+    edge_features=1,  # Optional: include edge features in attention
+    hidden_features=[64, 64],
+    out_features=1,
+    n_heads=4,
+    dropout_rate=0.1,
+    attention_dropout_rate=0.1,
+    negative_slope=0.2,
+)
+model = UncertaintyGAT(config, rngs=nnx.Rngs(0))
 
-        # Attention coefficients
-        src = nodes[graphs.senders]
-        dst = nodes[graphs.receivers]
-        e = self.attention(jnp.concatenate([src, dst], axis=-1))
-        alpha = jraph.segment_softmax(e, graphs.receivers, len(nodes))
+# Same API as UncertaintyGCN/UncertaintyMPNN
+mean, variance = model(batched_graphs, training=False)
 
-        # Aggregate with attention
-        messages = alpha * src
-        out = jraph.segment_sum(messages, graphs.receivers, len(nodes))
-        return out.reshape(-1, self.n_heads * self.head_dim)
+# Extract embeddings for Core-Set selection
+embeddings = model.extract_embeddings(batched_graphs)
 ```
 
 **Acceptance Criteria:**
-- [ ] Multi-head attention implementation
-- [ ] Edge feature incorporation option
-- [ ] Dropout on attention weights
+- [x] Multi-head attention implementation
+- [x] Edge feature incorporation option
+- [x] Dropout on attention weights
+- [x] Same API as UncertaintyGCN/UncertaintyMPNN for acquisition function compatibility
 
 ---
````
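The commit message also mentions "MC Dropout uncertainty via get_gat_uncertainties()". The general technique behind that utility can be sketched with a toy stochastic model in numpy (the model and function names here are illustrative stand-ins, not molax code): keep dropout active at inference, run several forward passes, and read the spread of the predictions as uncertainty.

```python
import numpy as np

rng = np.random.default_rng(0)

def noisy_forward(x, drop_rate=0.1):
    """Toy forward pass with dropout kept active (inverted-dropout scaling)."""
    mask = rng.random(x.shape) >= drop_rate
    return (x * mask / (1.0 - drop_rate)).sum(axis=-1)

def mc_dropout(x, n_samples=200):
    """Predictive mean and variance from repeated stochastic forward passes."""
    samples = np.stack([noisy_forward(x) for _ in range(n_samples)])
    return samples.mean(axis=0), samples.var(axis=0)

x = np.ones((4, 8))  # 4 "molecules", 8 features each
mean, var = mc_dropout(x)
```

The sample mean converges to the deterministic prediction while the sample variance stays strictly positive, reflecting the model's epistemic spread under dropout noise.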

examples/gat_demo.py

Lines changed: 190 additions & 0 deletions
```python
"""GAT Demo: Attention-based molecular property prediction.

This script demonstrates how to use the Graph Attention Network (GAT)
for molecular property prediction on the ESOL dataset. GAT uses learned
attention weights to dynamically weight neighbor contributions.
"""

from pathlib import Path

import flax.nnx as nnx
import jax.numpy as jnp
import jraph

from molax.models.gat import (
    GATConfig,
    UncertaintyGAT,
    create_gat_optimizer,
    eval_gat_step,
    train_gat_step,
)
from molax.models.gcn import GCNConfig, UncertaintyGCN, create_train_state, train_step
from molax.models.mpnn import (
    MPNNConfig,
    UncertaintyMPNN,
    create_mpnn_optimizer,
    train_mpnn_step,
)
from molax.utils.data import MolecularDataset

# Configuration
DATASET_PATH = Path(__file__).parent.parent / "datasets" / "esol.csv"
N_EPOCHS = 100
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
MAX_GRAD_NORM = 1.0

print("=" * 60)
print("GAT Demo: Attention-Based Molecular Property Prediction")
print("=" * 60)

# Load dataset
print("\nLoading ESOL dataset...")
dataset = MolecularDataset(DATASET_PATH)
train_data, test_data = dataset.split(test_size=0.2, seed=42)
print(f"Train: {len(train_data)} molecules, Test: {len(test_data)} molecules")

# Show feature info
sample_graph = train_data.graphs[0]
print("\nGraph features:")
print(f"  Node features: {sample_graph.nodes.shape[1]} (atom properties)")
print(f"  Edge features: {sample_graph.edges.shape[1]} (bond type)")

# Batch all data
print("\nBatching data...")
all_train_graphs = jraph.batch(train_data.graphs)
all_train_labels = train_data.labels
all_test_graphs = jraph.batch(test_data.graphs)
all_test_labels = test_data.labels

n_train = len(train_data)
n_test = len(test_data)
train_mask = jnp.ones(n_train, dtype=bool)
test_mask = jnp.ones(n_test, dtype=bool)

# Create GAT model
print("\nCreating GAT model...")
config = GATConfig(
    node_features=train_data.n_node_features,
    edge_features=1,  # Use bond type in attention
    hidden_features=[64, 64],
    out_features=1,
    n_heads=4,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    negative_slope=0.2,
)
model = UncertaintyGAT(config, rngs=nnx.Rngs(0))
optimizer = create_gat_optimizer(
    model,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    max_grad_norm=MAX_GRAD_NORM,
)

print(f"  Hidden layers: {config.hidden_features}")
print(f"  Attention heads: {config.n_heads}")
print(f"  Edge features in attention: {config.edge_features > 0}")
print(f"  Dropout rate: {config.dropout_rate}")
print(f"  Attention dropout: {config.attention_dropout_rate}")
print(f"  Weight decay: {WEIGHT_DECAY}")

# Training loop
print("\nTraining GAT...")
print("-" * 40)

for epoch in range(N_EPOCHS):
    # Training step
    train_loss = train_gat_step(
        model, optimizer, all_train_graphs, all_train_labels, train_mask
    )

    # Evaluation every 20 epochs
    if (epoch + 1) % 20 == 0:
        test_mse, _ = eval_gat_step(model, all_test_graphs, all_test_labels, test_mask)
        test_rmse = jnp.sqrt(test_mse)
        print(
            f"Epoch {epoch + 1:3d}: Train Loss = {float(train_loss):.4f}, "
            f"Test RMSE = {float(test_rmse):.4f}"
        )

# Final evaluation
print("-" * 40)
test_mse, predictions = eval_gat_step(
    model, all_test_graphs, all_test_labels, test_mask
)
test_rmse = jnp.sqrt(test_mse)

# Get predictions with uncertainty
mean, variance = model(all_test_graphs, training=False)
mean = mean.squeeze(-1)
variance = variance.squeeze(-1)

print("\nFinal Results:")
print(f"  Test RMSE: {float(test_rmse):.4f}")
print(f"  Mean predicted variance: {float(jnp.mean(variance[:n_test])):.4f}")

# Show some predictions
print("\nSample predictions (first 5 test molecules):")
print(f"{'Actual':>10} {'Predicted':>10} {'Std Dev':>10}")
for i in range(min(5, n_test)):
    actual = float(all_test_labels[i])
    pred = float(mean[i])
    std = float(jnp.sqrt(variance[i]))
    print(f"{actual:>10.3f} {pred:>10.3f} {std:>10.3f}")

# Compare with GCN and MPNN
print("\n" + "=" * 60)
print("Comparing GAT with GCN and MPNN...")
print("=" * 60)

# Train GCN
gcn_config = GCNConfig(
    node_features=train_data.n_node_features,
    hidden_features=[64, 64],
    out_features=1,
    dropout_rate=0.1,
)
gcn = UncertaintyGCN(gcn_config, rngs=nnx.Rngs(0))
gcn_optimizer = create_train_state(gcn, learning_rate=LEARNING_RATE)

print("\nTraining GCN for comparison...")
for epoch in range(N_EPOCHS):
    train_step(gcn, gcn_optimizer, all_train_graphs, all_train_labels, train_mask)

gcn_mean, _ = gcn(all_test_graphs, training=False)
gcn_mse = jnp.mean((gcn_mean.squeeze(-1)[:n_test] - all_test_labels[:n_test]) ** 2)
gcn_rmse = jnp.sqrt(gcn_mse)

# Train MPNN
mpnn_config = MPNNConfig(
    node_features=train_data.n_node_features,
    edge_features=1,
    hidden_features=[64, 64],
    out_features=1,
    aggregation="sum",
    dropout_rate=0.1,
)
mpnn = UncertaintyMPNN(mpnn_config, rngs=nnx.Rngs(0))
mpnn_optimizer = create_mpnn_optimizer(mpnn, learning_rate=LEARNING_RATE)

print("Training MPNN for comparison...")
for epoch in range(N_EPOCHS):
    train_mpnn_step(
        mpnn, mpnn_optimizer, all_train_graphs, all_train_labels, train_mask
    )

mpnn_mean, _ = mpnn(all_test_graphs, training=False)
mpnn_mse = jnp.mean((mpnn_mean.squeeze(-1)[:n_test] - all_test_labels[:n_test]) ** 2)
mpnn_rmse = jnp.sqrt(mpnn_mse)

print("\n" + "-" * 40)
print("Model Comparison (Test RMSE):")
print("-" * 40)
print(f"  GCN:  {float(gcn_rmse):.4f}")
print(f"  MPNN: {float(mpnn_rmse):.4f}")
print(f"  GAT:  {float(test_rmse):.4f}")

print("\n" + "=" * 60)
print("GAT demo completed successfully!")
print("=" * 60)
```

molax/models/__init__.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -16,6 +16,16 @@
     get_evidential_uncertainties,
     train_evidential_step,
 )
+from .gat import (
+    GATAttention,
+    GATConfig,
+    GATLayer,
+    UncertaintyGAT,
+    create_gat_optimizer,
+    eval_gat_step,
+    get_gat_uncertainties,
+    train_gat_step,
+)
 from .gcn import MolecularGCN, UncertaintyGCN
 from .mpnn import (
     MessageFunction,
@@ -45,6 +55,14 @@
     "train_evidential_step",
     "eval_evidential_step",
     "get_evidential_uncertainties",
+    "GATConfig",
+    "GATAttention",
+    "GATLayer",
+    "UncertaintyGAT",
+    "create_gat_optimizer",
+    "train_gat_step",
+    "eval_gat_step",
+    "get_gat_uncertainties",
     "MPNNConfig",
     "MessageFunction",
     "MessagePassingLayer",
```
