"""
Example demonstrating how to use UncertaintyGCN for molecular property prediction.

This script shows how to:
1. Load molecules and convert to jraph graphs
2. Initialize and use an UncertaintyGCN model
3. Interpret uncertainty in predictions
"""

import flax.nnx as nnx
import jax.numpy as jnp
import jraph
import matplotlib.pyplot as plt

from molax.models.gcn import GCNConfig, UncertaintyGCN
from molax.utils.data import smiles_to_jraph

# Create some sample molecules with varying complexity
molecules = [
    ("C", "methane"),
    ("CC", "ethane"),
    ("CCC", "propane"),
    ("CCCC", "butane"),
    ("c1ccccc1", "benzene"),
    ("CCO", "ethanol"),
    ("CC(=O)O", "acetic acid"),
    ("c1ccc(O)cc1", "phenol"),
]

print("=" * 60)
print("UncertaintyGCN Demo: Molecular Property Prediction")
print("=" * 60)

# Convert SMILES to jraph graphs, then batch them into a single
# GraphsTuple so the model can process all molecules in one call.
graphs = [smiles_to_jraph(smi) for smi, _ in molecules]
batched_graphs = jraph.batch(graphs)

print(f"\nLoaded {len(molecules)} molecules")
print(f"Node features: {graphs[0].nodes.shape[1]}")

# Create and initialize the UncertaintyGCN model.
# node_features is taken from the data so the config always matches
# whatever featurization smiles_to_jraph produced.
config = GCNConfig(
    node_features=graphs[0].nodes.shape[1],
    hidden_features=[32, 16],
    out_features=1,
    dropout_rate=0.1,
)
model = UncertaintyGCN(config, rngs=nnx.Rngs(42))

print("\nModel configuration:")
print(f"  Hidden layers: {config.hidden_features}")
print(f"  Dropout rate: {config.dropout_rate}")

# Make predictions with uncertainty
print("\n" + "-" * 60)
print("Predictions with Uncertainty")
print("-" * 60)

# training=False disables dropout: deterministic mean/variance heads.
mean, variance = model(batched_graphs, training=False)
mean = mean.squeeze(-1)
variance = variance.squeeze(-1)

print(f"{'Molecule':<15} {'Mean':>10} {'Std Dev':>10} {'95% CI'}")
print("-" * 60)

for i, (_, name) in enumerate(molecules):
    m = float(mean[i])
    std = float(jnp.sqrt(variance[i]))
    # 95% CI under a Gaussian assumption: mean +/- 1.96 * std.
    ci_low = m - 1.96 * std
    ci_high = m + 1.96 * std
    print(f"{name:<15} {m:>10.4f} {std:>10.4f} [{ci_low:.2f}, {ci_high:.2f}]")

# Demonstrate MC Dropout uncertainty
print("\n" + "-" * 60)
print("MC Dropout Uncertainty (10 samples)")
print("-" * 60)

mc_predictions = []
for _ in range(10):
    pred, _ = model(batched_graphs, training=True)  # Dropout active
    mc_predictions.append(pred.squeeze(-1))

# Spread across stochastic forward passes estimates epistemic uncertainty.
mc_predictions = jnp.stack(mc_predictions)
mc_mean = jnp.mean(mc_predictions, axis=0)
mc_std = jnp.std(mc_predictions, axis=0)

print(f"{'Molecule':<15} {'MC Mean':>10} {'MC Std':>10}")
print("-" * 40)

for i, (_, name) in enumerate(molecules):
    print(f"{name:<15} {float(mc_mean[i]):>10.4f} {float(mc_std[i]):>10.4f}")

# Visualize predictions
print("\n" + "-" * 60)
print("Creating visualization...")

fig, ax = plt.subplots(figsize=(10, 6))

x = range(len(molecules))
names = [name for _, name in molecules]
means = [float(mean[i]) for i in range(len(molecules))]
stds = [float(jnp.sqrt(variance[i])) for i in range(len(molecules))]

# Error bars show the 95% confidence interval (1.96 standard deviations).
ax.bar(x, means, yerr=[1.96 * s for s in stds], capsize=5, alpha=0.7)
ax.set_xticks(x)
ax.set_xticklabels(names, rotation=45, ha="right")
ax.set_ylabel("Predicted Value")
ax.set_title("UncertaintyGCN Predictions with 95% Confidence Intervals")
ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.savefig("examples/uncertainty_demo.png", dpi=150)
plt.close()

print("Saved visualization to 'examples/uncertainty_demo.png'")

print("\n" + "=" * 60)
print("Demo completed successfully!")
print("=" * 60)
print("\nKey takeaways:")
print("- UncertaintyGCN outputs both mean prediction and variance")
print("- Variance head predicts aleatoric (data) uncertainty")
print("- MC Dropout provides epistemic (model) uncertainty")
print("- 95% CI = mean ± 1.96 * std")