Skip to content

Commit a4e0165

Browse files
HFooladiclaude
andcommitted
docs: update changelog, roadmap, and CLAUDE.md for MPNN
- Add MPNN to CHANGELOG.md [Unreleased] section - Mark Phase 3.1 MPNN as complete in roadmap with acceptance criteria - Add MPNN to CLAUDE.md: data flow, key files, API section, testing Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 801c2e2 commit a4e0165

File tree

3 files changed

+53
-43
lines changed

3 files changed

+53
-43
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [Unreleased]
99

1010
### Added
11+
- **Message Passing Neural Network (MPNN)** (`molax/models/mpnn.py`)
12+
- `UncertaintyMPNN` model that leverages edge features (bond information)
13+
- `MPNNConfig` with configurable aggregation (sum, mean, max)
14+
- `MessageFunction` and `MessagePassingLayer` components
15+
- Training utilities: `train_mpnn_step`, `eval_mpnn_step`, `get_mpnn_uncertainties`
16+
- Same API as `UncertaintyGCN` for drop-in replacement with acquisition functions
17+
- Comprehensive tests (32 tests) and demo example
1118

1219
### Changed
1320

CLAUDE.md

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mkdocs serve
2626
# Run examples
2727
python examples/simple_active_learning.py
2828
python examples/active_learning_benchmark.py
29+
python examples/mpnn_demo.py
2930
python examples/ensemble_demo.py
3031
python examples/evidential_demo.py
3132
```
@@ -65,7 +66,7 @@ SMILES string
6566
jraph.GraphsTuple (single molecule)
6667
↓ jraph.batch()
6768
jraph.GraphsTuple (batched - all molecules as one graph)
68-
↓ UncertaintyGCN / DeepEnsemble / EvidentialGCN
69+
↓ UncertaintyGCN / UncertaintyMPNN / DeepEnsemble / EvidentialGCN
6970
(mean, variance) predictions
7071
```
7172

@@ -74,6 +75,7 @@ jraph.GraphsTuple (batched - all molecules as one graph)
7475
| File | Purpose |
7576
|------|---------|
7677
| `molax/models/gcn.py` | `GCNConfig`, `UncertaintyGCN`, `MolecularGCN`, `train_step`, `eval_step` |
78+
| `molax/models/mpnn.py` | `MPNNConfig`, `UncertaintyMPNN` for edge-aware message passing |
7779
| `molax/models/ensemble.py` | `EnsembleConfig`, `DeepEnsemble` for ensemble uncertainty |
7880
| `molax/models/evidential.py` | `EvidentialConfig`, `EvidentialGCN` for evidential uncertainty |
7981
| `molax/utils/data.py` | `MolecularDataset`, `smiles_to_jraph`, `batch_graphs` |
@@ -132,6 +134,25 @@ model = EvidentialGCN(config, rngs=nnx.Rngs(0))
132134
mean, aleatoric_var, epistemic_var = model(batched_graphs, training=False)
133135
```
134136

137+
### MPNN API
138+
139+
```python
140+
from molax.models.mpnn import MPNNConfig, UncertaintyMPNN
141+
142+
config = MPNNConfig(
143+
node_features=6,
144+
edge_features=1, # Bond type feature
145+
hidden_features=[64, 64],
146+
out_features=1,
147+
aggregation="sum", # or "mean", "max"
148+
dropout_rate=0.1,
149+
)
150+
model = UncertaintyMPNN(config, rngs=nnx.Rngs(0))
151+
152+
# Same API as UncertaintyGCN - uses edge features in message passing
153+
mean, variance = model(batched_graphs, training=False)
154+
```
155+
135156
### Calibration Metrics
136157

137158
```python
@@ -163,7 +184,8 @@ optimizer.update(model, grads)
163184

164185
```bash
165186
pytest tests/ -v # All tests
166-
pytest tests/test_gcn.py -v # Model tests
187+
pytest tests/test_gcn.py -v # GCN model tests
188+
pytest tests/test_mpnn.py -v # MPNN model tests
167189
pytest tests/test_ensemble.py -v # Ensemble tests
168190
pytest tests/test_evidential.py -v # Evidential tests
169191
pytest tests/test_acquisition.py -v # Acquisition tests

docs/roadmap.md

Lines changed: 22 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,9 @@ def expected_gradient_length(model, graphs, labels_placeholder):
384384

385385
Multiple architectures capture different inductive biases about molecular structure.
386386

387-
### 3.1 Message Passing Neural Network (MPNN)
387+
### 3.1 Message Passing Neural Network (MPNN) ✅
388+
389+
**Status:** Implemented in `molax/models/mpnn.py`
388390

389391
**What:** Generalized framework with explicit edge feature processing.
390392

@@ -394,51 +396,30 @@ Multiple architectures capture different inductive biases about molecular struct
394396

395397
```python
396398
# molax/models/mpnn.py
397-
from dataclasses import dataclass
399+
from molax.models.mpnn import MPNNConfig, UncertaintyMPNN
400+
401+
config = MPNNConfig(
402+
node_features=6,
403+
edge_features=1, # Bond type feature
404+
hidden_features=[64, 64],
405+
out_features=1,
406+
aggregation="sum", # or "mean", "max"
407+
dropout_rate=0.1,
408+
)
409+
model = UncertaintyMPNN(config, rngs=nnx.Rngs(0))
398410

399-
@dataclass
400-
class MPNNConfig:
401-
node_features: int
402-
edge_features: int # NEW: bond features
403-
hidden_features: int = 64
404-
message_passes: int = 3
405-
out_features: int = 1
406-
407-
class MPNN(nnx.Module):
408-
def __init__(self, config: MPNNConfig, rngs: nnx.Rngs):
409-
self.message_fn = nnx.Linear(
410-
config.node_features + config.edge_features,
411-
config.hidden_features,
412-
rngs=rngs
413-
)
414-
self.update_fn = nnx.GRU(
415-
config.hidden_features,
416-
config.hidden_features,
417-
rngs=rngs
418-
)
419-
# ... readout layers
411+
# Same API as UncertaintyGCN
412+
mean, variance = model(batched_graphs, training=False)
420413

421-
def __call__(self, graphs, training: bool = False):
422-
nodes = graphs.nodes
423-
edges = graphs.edges # Bond features
424-
425-
for _ in range(self.config.message_passes):
426-
# Message: combine source node + edge features
427-
messages = self.message_fn(
428-
jnp.concatenate([nodes[graphs.senders], edges], axis=-1)
429-
)
430-
# Aggregate: sum messages per node
431-
aggregated = jraph.segment_sum(messages, graphs.receivers, len(nodes))
432-
# Update: GRU update
433-
nodes = self.update_fn(nodes, aggregated)
434-
435-
return self.readout(nodes, graphs)
414+
# Extract embeddings for Core-Set selection
415+
embeddings = model.extract_embeddings(batched_graphs)
436416
```
437417

438418
**Acceptance Criteria:**
439-
- [ ] MPNN with edge feature support
440-
- [ ] Configurable message/update functions
441-
- [ ] GRU and MLP update variants
419+
- [x] MPNN with edge feature support
420+
- [x] Configurable aggregation (sum, mean, max)
421+
- [x] Same API as UncertaintyGCN for acquisition function compatibility
422+
- [x] MC Dropout uncertainty via `get_mpnn_uncertainties()`
442423

443424
---
444425

0 commit comments

Comments
 (0)