Skip to content

Commit 46d37c0

Browse files
HFooladi and claude
committed
fix: update uncertainty_gcn_demo.py to use current jraph-based API
The demo was using an outdated API (UncertaintyGCNConfig, adjacency matrices). Updated to use current API (GCNConfig, jraph graphs, smiles_to_jraph). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent dbcce34 commit 46d37c0

File tree

1 file changed

+110
-102
lines changed

1 file changed

+110
-102
lines changed

examples/uncertainty_gcn_demo.py

Lines changed: 110 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -2,115 +2,123 @@
22
Example demonstrating how to use UncertaintyGCN for molecular property prediction.
33
44
This script shows how to:
5-
1. Create a simple molecule graph
5+
1. Load molecules and convert to jraph graphs
66
2. Initialize and use an UncertaintyGCN model
77
3. Interpret uncertainty in predictions
88
"""
99

10-
import jax
10+
import flax.nnx as nnx
1111
import jax.numpy as jnp
12+
import jraph
1213
import matplotlib.pyplot as plt
13-
import numpy as np
14-
from flax import nnx
15-
16-
from molax.models.gcn import UncertaintyGCN, UncertaintyGCNConfig
17-
18-
# Set random seed for reproducibility
19-
key = jax.random.PRNGKey(42)
20-
rngs = nnx.Rngs(0, params=1, dropout=2)
21-
22-
# Create a simple molecule graph (6 atoms with 4 features per atom)
23-
num_atoms = 6
24-
in_features = 4
25-
26-
# 1. Create random atom features
27-
key, subkey = jax.random.split(key)
28-
atom_features = jax.random.normal(subkey, (num_atoms, in_features))
29-
30-
# 2. Create a ring-shaped molecular graph
31-
adjacency_matrix = np.zeros((num_atoms, num_atoms))
32-
for i in range(num_atoms):
33-
# Connect each atom to its neighbors (creating a ring)
34-
adjacency_matrix[i, (i + 1) % num_atoms] = 1.0
35-
adjacency_matrix[i, (i - 1) % num_atoms] = 1.0
36-
adjacency_matrix = jnp.array(adjacency_matrix)
37-
38-
print("atom_features.shape", atom_features.shape)
39-
print("adjacency_matrix.shape", adjacency_matrix.shape)
40-
41-
# 3. Create and initialize the UncertaintyGCN model
42-
config = UncertaintyGCNConfig(
43-
in_features=in_features, # Number of input features per atom
44-
hidden_features=[32, 16, 8], # GCN layer sizes
45-
out_features=1, # Single property prediction
46-
dropout_rate=0.1, # Dropout for regularization
47-
n_heads=2,
48-
rngs=rngs,
49-
)
50-
51-
model = UncertaintyGCN(config)
52-
53-
54-
# 4. Make a prediction with uncertainty
55-
mean, variance = model(atom_features, adjacency_matrix)
5614

57-
print(f"Predicted property value: {mean[0]:.4f}")
58-
print(f"Prediction uncertainty (variance): {variance[0]:.4f}")
59-
print(
60-
f"95% confidence interval: ({mean[0] - 1.96 * jnp.sqrt(variance[0]):.4f}, "
61-
f"{mean[0] + 1.96 * jnp.sqrt(variance[0]):.4f})"
15+
from molax.models.gcn import GCNConfig, UncertaintyGCN
16+
from molax.utils.data import smiles_to_jraph
17+
18+
# Create some sample molecules with varying complexity
19+
molecules = [
20+
("C", "methane"),
21+
("CC", "ethane"),
22+
("CCC", "propane"),
23+
("CCCC", "butane"),
24+
("c1ccccc1", "benzene"),
25+
("CCO", "ethanol"),
26+
("CC(=O)O", "acetic acid"),
27+
("c1ccc(O)cc1", "phenol"),
28+
]
29+
30+
print("=" * 60)
31+
print("UncertaintyGCN Demo: Molecular Property Prediction")
32+
print("=" * 60)
33+
34+
# Convert SMILES to jraph graphs
35+
graphs = [smiles_to_jraph(smi) for smi, _ in molecules]
36+
batched_graphs = jraph.batch(graphs)
37+
38+
print(f"\nLoaded {len(molecules)} molecules")
39+
print(f"Node features: {graphs[0].nodes.shape[1]}")
40+
41+
# Create and initialize the UncertaintyGCN model
42+
config = GCNConfig(
43+
node_features=graphs[0].nodes.shape[1],
44+
hidden_features=[32, 16],
45+
out_features=1,
46+
dropout_rate=0.1,
6247
)
63-
64-
# 6. Demonstrate uncertainty behavior with modified inputs
65-
test_points = 50
66-
scaling_factors = jnp.linspace(0.1, 10.0, test_points)
67-
means = []
68-
uncertainties = []
69-
70-
for scale in scaling_factors:
71-
# Scale the input features to create increasingly out-of-distribution examples
72-
scaled_features = atom_features * scale
73-
mean, var = model(scaled_features, adjacency_matrix)
74-
means.append(mean[0])
75-
uncertainties.append(
76-
jnp.sqrt(var[0])
77-
) # Use standard deviation for easier interpretation
78-
79-
# Convert to arrays
80-
means = jnp.array(means)
81-
uncertainties = jnp.array(uncertainties)
82-
83-
# 7. Visualize how uncertainty changes with input distribution shift
84-
plt.figure(figsize=(10, 6))
85-
plt.plot(scaling_factors, means, "b-", label="Prediction")
86-
plt.fill_between(
87-
scaling_factors,
88-
means - 1.96 * uncertainties,
89-
means + 1.96 * uncertainties,
90-
alpha=0.3,
91-
color="b",
92-
label="95% Confidence Interval",
93-
)
94-
plt.xlabel("Input Scaling Factor")
95-
plt.ylabel("Predicted Property")
96-
plt.title("Prediction with Uncertainty for Different Input Scales")
97-
plt.legend()
98-
plt.grid(True)
99-
plt.savefig("examples/uncertainty_demo.png")
48+
model = UncertaintyGCN(config, rngs=nnx.Rngs(42))
49+
50+
print("\nModel configuration:")
51+
print(f" Hidden layers: {config.hidden_features}")
52+
print(f" Dropout rate: {config.dropout_rate}")
53+
54+
# Make predictions with uncertainty
55+
print("\n" + "-" * 60)
56+
print("Predictions with Uncertainty")
57+
print("-" * 60)
58+
59+
mean, variance = model(batched_graphs, training=False)
60+
mean = mean.squeeze(-1)
61+
variance = variance.squeeze(-1)
62+
63+
print(f"{'Molecule':<15} {'Mean':>10} {'Std Dev':>10} {'95% CI'}")
64+
print("-" * 60)
65+
66+
for i, (_, name) in enumerate(molecules):
67+
m = float(mean[i])
68+
std = float(jnp.sqrt(variance[i]))
69+
ci_low = m - 1.96 * std
70+
ci_high = m + 1.96 * std
71+
print(f"{name:<15} {m:>10.4f} {std:>10.4f} [{ci_low:.2f}, {ci_high:.2f}]")
72+
73+
# Demonstrate MC Dropout uncertainty
74+
print("\n" + "-" * 60)
75+
print("MC Dropout Uncertainty (10 samples)")
76+
print("-" * 60)
77+
78+
mc_predictions = []
79+
for _ in range(10):
80+
pred, _ = model(batched_graphs, training=True) # Dropout active
81+
mc_predictions.append(pred.squeeze(-1))
82+
83+
mc_predictions = jnp.stack(mc_predictions)
84+
mc_mean = jnp.mean(mc_predictions, axis=0)
85+
mc_std = jnp.std(mc_predictions, axis=0)
86+
87+
print(f"{'Molecule':<15} {'MC Mean':>10} {'MC Std':>10}")
88+
print("-" * 40)
89+
90+
for i, (_, name) in enumerate(molecules):
91+
print(f"{name:<15} {float(mc_mean[i]):>10.4f} {float(mc_std[i]):>10.4f}")
92+
93+
# Visualize predictions
94+
print("\n" + "-" * 60)
95+
print("Creating visualization...")
96+
97+
fig, ax = plt.subplots(figsize=(10, 6))
98+
99+
x = range(len(molecules))
100+
names = [name for _, name in molecules]
101+
means = [float(mean[i]) for i in range(len(molecules))]
102+
stds = [float(jnp.sqrt(variance[i])) for i in range(len(molecules))]
103+
104+
ax.bar(x, means, yerr=[1.96 * s for s in stds], capsize=5, alpha=0.7)
105+
ax.set_xticks(x)
106+
ax.set_xticklabels(names, rotation=45, ha="right")
107+
ax.set_ylabel("Predicted Value")
108+
ax.set_title("UncertaintyGCN Predictions with 95% Confidence Intervals")
109+
ax.grid(axis="y", alpha=0.3)
110+
111+
plt.tight_layout()
112+
plt.savefig("examples/uncertainty_demo.png", dpi=150)
100113
plt.close()
101114

102-
print("\nGenerating out-of-distribution examples:")
103-
for scale in [0.1, 1.0, 10.0]:
104-
scaled_features = atom_features * scale
105-
mean, var = model(scaled_features, adjacency_matrix)
106-
std = jnp.sqrt(var[0])
107-
print(
108-
f"Scale {scale:.1f}: {mean[0]:.4f} ± {std:.4f} "
109-
f"(95% CI: {mean[0] - 1.96 * std:.4f} to {mean[0] + 1.96 * std:.4f})"
110-
)
111-
112-
print("\nDemo completed. Saved visualization to 'uncertainty_demo.png'")
113-
print(
114-
"This demonstrates how uncertainty increases as inputs become more "
115-
"out-of-distribution."
116-
)
115+
print("Saved visualization to 'examples/uncertainty_demo.png'")
116+
117+
print("\n" + "=" * 60)
118+
print("Demo completed successfully!")
119+
print("=" * 60)
120+
print("\nKey takeaways:")
121+
print("- UncertaintyGCN outputs both mean prediction and variance")
122+
print("- Variance head predicts aleatoric (data) uncertainty")
123+
print("- MC Dropout provides epistemic (model) uncertainty")
124+
print("- 95% CI = mean ± 1.96 * std")

0 commit comments

Comments (0)