Skip to content

Commit 801c2e2

Browse files
HFooladi and claude committed
feat: implement Phase 3.1 Message Passing Neural Network (MPNN)
Add an MPNN model that leverages edge features (bond information) in the message passing computation, enabling edge-aware molecular property prediction.

Key components:
- MPNNConfig: configuration with edge_features and aggregation params
- MessageFunction: MLP computing messages from sender + receiver + edge features
- MessagePassingLayer: configurable aggregation (sum/mean/max)
- UncertaintyMPNN: full model with the same API as UncertaintyGCN

Includes comprehensive tests (32 tests) and a demo example on the ESOL dataset.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent a79d451 commit 801c2e2

File tree

4 files changed

+1141
-0
lines changed

4 files changed

+1141
-0
lines changed

examples/mpnn_demo.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
"""MPNN Demo: Edge-aware molecular property prediction.

This script demonstrates how to use the Message Passing Neural Network (MPNN)
for molecular property prediction on the ESOL dataset. Unlike GCN, MPNN
leverages edge features (bond information) in the message passing computation.
"""

from pathlib import Path

import flax.nnx as nnx
import jax.numpy as jnp
import jraph

from molax.models.mpnn import (
    MPNNConfig,
    UncertaintyMPNN,
    create_mpnn_optimizer,
    eval_mpnn_step,
    train_mpnn_step,
)
from molax.utils.data import MolecularDataset

# Script configuration.
DATASET_PATH = Path(__file__).parent.parent / "datasets" / "esol.csv"
N_EPOCHS = 100
LEARNING_RATE = 5e-4  # Lower learning rate for stability

# Reusable separators for console output.
BANNER = "=" * 60
RULE = "-" * 40

print(BANNER)
print("MPNN Demo: Edge-Aware Molecular Property Prediction")
print(BANNER)

# --- Data loading --------------------------------------------------------
print("\nLoading ESOL dataset...")
esol = MolecularDataset(DATASET_PATH)
train_split, test_split = esol.split(test_size=0.2, seed=42)
print(f"Train: {len(train_split)} molecules, Test: {len(test_split)} molecules")

# Peek at one graph to report the feature dimensionality.
first_graph = train_split.graphs[0]
print("\nGraph features:")
print(f" Node features: {first_graph.nodes.shape[1]} (atom properties)")
print(f" Edge features: {first_graph.edges.shape[1]} (bond type)")

# --- Batching ------------------------------------------------------------
# Both splits fit in memory, so each split becomes a single jraph batch.
print("\nBatching data...")
batched_train = jraph.batch(train_split.graphs)
train_labels = train_split.labels
batched_test = jraph.batch(test_split.graphs)
test_labels = test_split.labels

n_train = len(train_split)
n_test = len(test_split)
# Every molecule participates in loss/metrics, so masks are all-True.
train_mask = jnp.ones(n_train, dtype=bool)
test_mask = jnp.ones(n_test, dtype=bool)

# --- Model ---------------------------------------------------------------
print("\nCreating MPNN model...")
model_config = MPNNConfig(
    node_features=train_split.n_node_features,
    edge_features=1,  # Bond type feature
    hidden_features=[64, 64],
    out_features=1,
    aggregation="sum",
    dropout_rate=0.1,
)
model = UncertaintyMPNN(model_config, rngs=nnx.Rngs(0))
optimizer = create_mpnn_optimizer(model, learning_rate=LEARNING_RATE)

print(f" Hidden layers: {model_config.hidden_features}")
print(f" Aggregation: {model_config.aggregation}")
print(f" Dropout rate: {model_config.dropout_rate}")

# --- Training ------------------------------------------------------------
print("\nTraining MPNN...")
print(RULE)

for epoch in range(1, N_EPOCHS + 1):
    # One full-batch gradient step per epoch.
    train_loss = train_mpnn_step(
        model, optimizer, batched_train, train_labels, train_mask
    )

    # Report held-out RMSE every 20 epochs.
    if epoch % 20 == 0:
        test_mse, _ = eval_mpnn_step(model, batched_test, test_labels, test_mask)
        test_rmse = jnp.sqrt(test_mse)
        print(
            f"Epoch {epoch:3d}: Train Loss = {float(train_loss):.4f}, "
            f"Test RMSE = {float(test_rmse):.4f}"
        )

# --- Final evaluation ----------------------------------------------------
print(RULE)
test_mse, _predictions = eval_mpnn_step(
    model, batched_test, test_labels, test_mask
)
test_rmse = jnp.sqrt(test_mse)

# Forward pass in inference mode to get per-molecule mean and variance.
pred_mean, pred_var = model(batched_test, training=False)
pred_mean = pred_mean.squeeze(-1)
pred_var = pred_var.squeeze(-1)

print("\nFinal Results:")
print(f" Test RMSE: {float(test_rmse):.4f}")
print(f" Mean predicted variance: {float(jnp.mean(pred_var[:n_test])):.4f}")

# Show a handful of individual predictions with their uncertainty.
print("\nSample predictions (first 5 test molecules):")
print(f"{'Actual':>10} {'Predicted':>10} {'Std Dev':>10}")
for idx in range(min(5, n_test)):
    actual = float(test_labels[idx])
    predicted = float(pred_mean[idx])
    std_dev = float(jnp.sqrt(pred_var[idx]))
    print(f"{actual:>10.3f} {predicted:>10.3f} {std_dev:>10.3f}")

print("\n" + BANNER)
print("MPNN demo completed successfully!")
print(BANNER)

molax/models/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@
1717
train_evidential_step,
1818
)
1919
from .gcn import MolecularGCN, UncertaintyGCN
20+
from .mpnn import (
21+
MessageFunction,
22+
MessagePassingLayer,
23+
MPNNConfig,
24+
UncertaintyMPNN,
25+
create_mpnn_optimizer,
26+
eval_mpnn_step,
27+
get_mpnn_uncertainties,
28+
train_mpnn_step,
29+
)
2030

2131
__all__ = [
2232
"MolecularGCN",
@@ -35,4 +45,12 @@
3545
"train_evidential_step",
3646
"eval_evidential_step",
3747
"get_evidential_uncertainties",
48+
"MPNNConfig",
49+
"MessageFunction",
50+
"MessagePassingLayer",
51+
"UncertaintyMPNN",
52+
"create_mpnn_optimizer",
53+
"train_mpnn_step",
54+
"eval_mpnn_step",
55+
"get_mpnn_uncertainties",
3856
]

0 commit comments

Comments (0)