|
| 1 | +"""MPNN Demo: Edge-aware molecular property prediction. |
| 2 | +
|
| 3 | +This script demonstrates how to use the Message Passing Neural Network (MPNN) |
| 4 | +for molecular property prediction on the ESOL dataset. Unlike GCN, MPNN |
| 5 | +leverages edge features (bond information) in the message passing computation. |
| 6 | +""" |
| 7 | + |
| 8 | +from pathlib import Path |
| 9 | + |
| 10 | +import flax.nnx as nnx |
| 11 | +import jax.numpy as jnp |
| 12 | +import jraph |
| 13 | + |
| 14 | +from molax.models.mpnn import ( |
| 15 | + MPNNConfig, |
| 16 | + UncertaintyMPNN, |
| 17 | + create_mpnn_optimizer, |
| 18 | + eval_mpnn_step, |
| 19 | + train_mpnn_step, |
| 20 | +) |
| 21 | +from molax.utils.data import MolecularDataset |
| 22 | + |
| 23 | +# Configuration |
| 24 | +DATASET_PATH = Path(__file__).parent.parent / "datasets" / "esol.csv" |
| 25 | +N_EPOCHS = 100 |
| 26 | +LEARNING_RATE = 5e-4 # Lower learning rate for stability |
| 27 | + |
| 28 | +print("=" * 60) |
| 29 | +print("MPNN Demo: Edge-Aware Molecular Property Prediction") |
| 30 | +print("=" * 60) |
| 31 | + |
| 32 | +# Load dataset |
| 33 | +print("\nLoading ESOL dataset...") |
| 34 | +dataset = MolecularDataset(DATASET_PATH) |
| 35 | +train_data, test_data = dataset.split(test_size=0.2, seed=42) |
| 36 | +print(f"Train: {len(train_data)} molecules, Test: {len(test_data)} molecules") |
| 37 | + |
| 38 | +# Show edge feature info |
| 39 | +sample_graph = train_data.graphs[0] |
| 40 | +print("\nGraph features:") |
| 41 | +print(f" Node features: {sample_graph.nodes.shape[1]} (atom properties)") |
| 42 | +print(f" Edge features: {sample_graph.edges.shape[1]} (bond type)") |
| 43 | + |
| 44 | +# Batch all data |
| 45 | +print("\nBatching data...") |
| 46 | +all_train_graphs = jraph.batch(train_data.graphs) |
| 47 | +all_train_labels = train_data.labels |
| 48 | +all_test_graphs = jraph.batch(test_data.graphs) |
| 49 | +all_test_labels = test_data.labels |
| 50 | + |
| 51 | +n_train = len(train_data) |
| 52 | +n_test = len(test_data) |
| 53 | +train_mask = jnp.ones(n_train, dtype=bool) |
| 54 | +test_mask = jnp.ones(n_test, dtype=bool) |
| 55 | + |
| 56 | +# Create MPNN model |
| 57 | +print("\nCreating MPNN model...") |
| 58 | +config = MPNNConfig( |
| 59 | + node_features=train_data.n_node_features, |
| 60 | + edge_features=1, # Bond type feature |
| 61 | + hidden_features=[64, 64], |
| 62 | + out_features=1, |
| 63 | + aggregation="sum", |
| 64 | + dropout_rate=0.1, |
| 65 | +) |
| 66 | +model = UncertaintyMPNN(config, rngs=nnx.Rngs(0)) |
| 67 | +optimizer = create_mpnn_optimizer(model, learning_rate=LEARNING_RATE) |
| 68 | + |
| 69 | +print(f" Hidden layers: {config.hidden_features}") |
| 70 | +print(f" Aggregation: {config.aggregation}") |
| 71 | +print(f" Dropout rate: {config.dropout_rate}") |
| 72 | + |
| 73 | +# Training loop |
| 74 | +print("\nTraining MPNN...") |
| 75 | +print("-" * 40) |
| 76 | + |
| 77 | +for epoch in range(N_EPOCHS): |
| 78 | + # Training step |
| 79 | + train_loss = train_mpnn_step( |
| 80 | + model, optimizer, all_train_graphs, all_train_labels, train_mask |
| 81 | + ) |
| 82 | + |
| 83 | + # Evaluation every 20 epochs |
| 84 | + if (epoch + 1) % 20 == 0: |
| 85 | + test_mse, _ = eval_mpnn_step(model, all_test_graphs, all_test_labels, test_mask) |
| 86 | + test_rmse = jnp.sqrt(test_mse) |
| 87 | + print( |
| 88 | + f"Epoch {epoch + 1:3d}: Train Loss = {float(train_loss):.4f}, " |
| 89 | + f"Test RMSE = {float(test_rmse):.4f}" |
| 90 | + ) |
| 91 | + |
| 92 | +# Final evaluation |
| 93 | +print("-" * 40) |
| 94 | +test_mse, predictions = eval_mpnn_step( |
| 95 | + model, all_test_graphs, all_test_labels, test_mask |
| 96 | +) |
| 97 | +test_rmse = jnp.sqrt(test_mse) |
| 98 | + |
| 99 | +# Get predictions with uncertainty |
| 100 | +mean, variance = model(all_test_graphs, training=False) |
| 101 | +mean = mean.squeeze(-1) |
| 102 | +variance = variance.squeeze(-1) |
| 103 | + |
| 104 | +print("\nFinal Results:") |
| 105 | +print(f" Test RMSE: {float(test_rmse):.4f}") |
| 106 | +print(f" Mean predicted variance: {float(jnp.mean(variance[:n_test])):.4f}") |
| 107 | + |
| 108 | +# Show some predictions |
| 109 | +print("\nSample predictions (first 5 test molecules):") |
| 110 | +print(f"{'Actual':>10} {'Predicted':>10} {'Std Dev':>10}") |
| 111 | +for i in range(min(5, n_test)): |
| 112 | + actual = float(all_test_labels[i]) |
| 113 | + pred = float(mean[i]) |
| 114 | + std = float(jnp.sqrt(variance[i])) |
| 115 | + print(f"{actual:>10.3f} {pred:>10.3f} {std:>10.3f}") |
| 116 | + |
| 117 | +print("\n" + "=" * 60) |
| 118 | +print("MPNN demo completed successfully!") |
| 119 | +print("=" * 60) |
0 commit comments