
Commit ec4bd7c

HFooladi and claude committed
fix: update CLAUDE.md with correct names, missing modules, and condensed API docs
Fix incorrect filenames (ensemble_demo.py → ensemble_active_learning.py, evidential_demo.py → evidential_active_learning.py) and wrong function names (calibration_report → evaluate_calibration, plot_calibration_curve → plot_reliability_diagram). Add missing acquisition modules (BALD, batch-aware, CoreSet, EGL), missing test files, and missing examples. Condense 6 repetitive model API sections into a compact table.

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 9a26918 commit ec4bd7c

File tree

1 file changed (+53, -121 lines)

CLAUDE.md

Lines changed: 53 additions & 121 deletions
@@ -29,8 +29,11 @@ python examples/active_learning_benchmark.py
 python examples/mpnn_demo.py
 python examples/gat_demo.py
 python examples/graph_transformer_demo.py
-python examples/ensemble_demo.py
-python examples/evidential_demo.py
+python examples/ensemble_active_learning.py
+python examples/evidential_active_learning.py
+python examples/acquisition_strategies_demo.py
+python examples/calibration_comparison.py
+python examples/uncertainty_gcn_demo.py
 ```

 ## Architecture
@@ -83,141 +86,66 @@ jraph.GraphsTuple (batched - all molecules as one graph)
 | `molax/models/ensemble.py` | `EnsembleConfig`, `DeepEnsemble` for ensemble uncertainty |
 | `molax/models/evidential.py` | `EvidentialConfig`, `EvidentialGCN` for evidential uncertainty |
 | `molax/utils/data.py` | `MolecularDataset`, `smiles_to_jraph`, `batch_graphs` |
-| `molax/acquisition/uncertainty.py` | `uncertainty_sampling`, `ensemble_uncertainty_sampling`, `evidential_uncertainty_sampling` |
-| `molax/acquisition/diversity.py` | `diversity_sampling` |
-| `molax/metrics/calibration.py` | `expected_calibration_error`, `calibration_report` |
-| `molax/metrics/visualization.py` | `plot_calibration_curve`, `plot_reliability_diagram` |
+| `molax/acquisition/uncertainty.py` | `uncertainty_sampling`, `ensemble_uncertainty_sampling`, `evidential_uncertainty_sampling`, `diversity_sampling`, `combined_*` strategies |
+| `molax/acquisition/bald.py` | `bald_sampling`, `ensemble_bald_sampling`, `evidential_bald_sampling` (mutual information) |
+| `molax/acquisition/batch_aware.py` | `batch_bald_sampling`, `dpp_sampling`, `combined_batch_acquisition` |
+| `molax/acquisition/coreset.py` | `coreset_sampling`, `coreset_sampling_with_scores` (K-center greedy) |
+| `molax/acquisition/expected_change.py` | `egl_sampling`, `egl_acquisition` (expected gradient length) |
+| `molax/metrics/calibration.py` | `expected_calibration_error`, `evaluate_calibration`, `TemperatureScaling`, `negative_log_likelihood`, `sharpness` |
+| `molax/metrics/visualization.py` | `plot_reliability_diagram`, `plot_calibration_comparison`, `plot_uncertainty_vs_error`, `create_calibration_report` |
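The CoreSet entry above is described as K-center greedy. As an illustrative plain-NumPy sketch of that selection technique (the name `k_center_greedy` is hypothetical here, not molax's `coreset_sampling`):

```python
import numpy as np

def k_center_greedy(embeddings: np.ndarray, labeled_idx: list, n_select: int) -> list:
    """K-center greedy: repeatedly pick the pool point farthest
    from the current labeled/selected set."""
    # Distance from every point to its nearest labeled point
    dists = np.min(
        np.linalg.norm(
            embeddings[:, None, :] - embeddings[labeled_idx][None, :, :], axis=-1
        ),
        axis=1,
    )
    selected = []
    for _ in range(n_select):
        idx = int(np.argmax(dists))  # farthest point from the covered set
        selected.append(idx)
        # Newly selected point becomes a center; update nearest distances
        new_d = np.linalg.norm(embeddings - embeddings[idx], axis=-1)
        dists = np.minimum(dists, new_d)
    return selected
```

Because a selected point's distance drops to zero, the loop never re-picks a point, and each pick maximizes coverage of the embedding space.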

 ### Model API

-```python
-from molax.models.gcn import GCNConfig, UncertaintyGCN
-
-config = GCNConfig(
-    node_features=6,           # Atom features
-    hidden_features=[64, 64],  # GCN layers
-    out_features=1,            # Output dim
-    dropout_rate=0.1,
-)
-model = UncertaintyGCN(config, rngs=nnx.Rngs(0))
+All models share a common pattern: `Config` → `Model(config, rngs)` → `model(graphs, training)` → `(mean, variance)`.

-# Forward pass
-mean, variance = model(batched_graphs, training=True)
-```
+| Model | Config | Extra Config Params | Output |
+|-------|--------|---------------------|--------|
+| `UncertaintyGCN` | `GCNConfig` | | `(mean, var)` |
+| `UncertaintyMPNN` | `MPNNConfig` | `edge_features`, `aggregation` | `(mean, var)` |
+| `UncertaintyGAT` | `GATConfig` | `n_heads`, `attention_dropout_rate`, `negative_slope` | `(mean, var)` |
+| `UncertaintyGraphTransformer` | `GraphTransformerConfig` | `n_heads`, `ffn_ratio`, `pe_type`, `pe_dim` | `(mean, var)` |
+| `DeepEnsemble` | `EnsembleConfig` | `n_members` | `(mean, epistemic_var, aleatoric_var)` |
+| `EvidentialGCN` | `EvidentialConfig` | | `(mean, aleatoric_var, epistemic_var)` |

-### Ensemble API
+All configs share: `node_features`, `hidden_features`, `out_features`, `dropout_rate`. Edge-aware models (MPNN, GAT, GraphTransformer) also take `edge_features`.

 ```python
-from molax.models.ensemble import EnsembleConfig, DeepEnsemble
+# Example: any model follows this pattern
+from molax.models import GCNConfig, UncertaintyGCN

-config = EnsembleConfig(
-    node_features=6,
-    hidden_features=[64, 64],
-    out_features=1,
-    n_members=5,
-)
-ensemble = DeepEnsemble(config, rngs=nnx.Rngs(0))
+config = GCNConfig(node_features=6, hidden_features=[64, 64], out_features=1, dropout_rate=0.1)
+model = UncertaintyGCN(config, rngs=nnx.Rngs(0))
+mean, variance = model(batched_graphs, training=True)

-# Returns separate epistemic and aleatoric uncertainty
+# Ensemble returns 3 values
+from molax.models import EnsembleConfig, DeepEnsemble
+ensemble = DeepEnsemble(EnsembleConfig(node_features=6, hidden_features=[64, 64], out_features=1, n_members=5), rngs=nnx.Rngs(0))
 mean, epistemic_var, aleatoric_var = ensemble(batched_graphs, training=False)
-```
-
-### Evidential API
-
-```python
-from molax.models.evidential import EvidentialConfig, EvidentialGCN
-
-config = EvidentialConfig(
-    node_features=6,
-    hidden_features=[64, 64],
-    out_features=1,
-)
-model = EvidentialGCN(config, rngs=nnx.Rngs(0))
-
-# Single forward pass for both uncertainties
-mean, aleatoric_var, epistemic_var = model(batched_graphs, training=False)
-```
-
-### MPNN API
-
-```python
-from molax.models.mpnn import MPNNConfig, UncertaintyMPNN
-
-config = MPNNConfig(
-    node_features=6,
-    edge_features=1,           # Bond type feature
-    hidden_features=[64, 64],
-    out_features=1,
-    aggregation="sum",         # or "mean", "max"
-    dropout_rate=0.1,
-)
-model = UncertaintyMPNN(config, rngs=nnx.Rngs(0))
-
-# Same API as UncertaintyGCN - uses edge features in message passing
-mean, variance = model(batched_graphs, training=False)
-```
-
-### GAT API

-```python
-from molax.models.gat import GATConfig, UncertaintyGAT
-
-config = GATConfig(
-    node_features=6,
-    edge_features=1,           # Optional: include edge features in attention
-    hidden_features=[64, 64],
-    out_features=1,
-    n_heads=4,                 # Multi-head attention
-    dropout_rate=0.1,
-    attention_dropout_rate=0.1,
-    negative_slope=0.2,        # LeakyReLU slope
-)
-model = UncertaintyGAT(config, rngs=nnx.Rngs(0))
-
-# Same API as UncertaintyGCN/UncertaintyMPNN - uses attention for aggregation
-mean, variance = model(batched_graphs, training=False)
-```
-
-### Graph Transformer API
-
-```python
-from molax.models.graph_transformer import GraphTransformerConfig, UncertaintyGraphTransformer
-
-config = GraphTransformerConfig(
-    node_features=6,
-    edge_features=1,           # Optional: edge features as attention bias
-    hidden_features=[64, 64],
-    out_features=1,
-    n_heads=4,                 # Multi-head self-attention
-    ffn_ratio=4.0,             # FFN hidden dim = 4 * model dim
-    dropout_rate=0.1,
-    attention_dropout_rate=0.1,
-    pe_type="rwpe",            # Positional encoding: "rwpe", "laplacian", or "none"
-    pe_dim=16,                 # Positional encoding dimension
-)
-model = UncertaintyGraphTransformer(config, rngs=nnx.Rngs(0))
-
-# Same API as UncertaintyGCN/UncertaintyMPNN/UncertaintyGAT
-mean, variance = model(batched_graphs, training=False)
-
-# Extract embeddings for Core-Set selection
+# GraphTransformer supports embedding extraction for CoreSet
 embeddings = model.extract_embeddings(batched_graphs)
 ```
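The three-value ensemble return in the example above follows the standard deep-ensemble decomposition: epistemic uncertainty is the disagreement between members, aleatoric is their average predicted noise. A plain-NumPy sketch of that combination rule, assuming each member predicts a Gaussian (illustrative only, not molax's internals):

```python
import numpy as np

def ensemble_decompose(member_means: np.ndarray, member_vars: np.ndarray):
    """Combine per-member Gaussian predictions, shape [n_members, n_samples].

    epistemic = variance of the member means (model disagreement)
    aleatoric = mean of the member variances (predicted data noise)
    """
    mean = member_means.mean(axis=0)
    epistemic_var = member_means.var(axis=0)
    aleatoric_var = member_vars.mean(axis=0)
    return mean, epistemic_var, aleatoric_var
```

Total predictive variance, when needed, is simply `epistemic_var + aleatoric_var`.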

 ### Calibration Metrics

 ```python
-from molax.metrics import expected_calibration_error, calibration_report
-from molax.metrics.visualization import plot_calibration_curve
+from molax.metrics import expected_calibration_error, evaluate_calibration, TemperatureScaling
+from molax.metrics import plot_reliability_diagram, create_calibration_report

 # Compute ECE
 ece = expected_calibration_error(predictions, variances, targets)

-# Generate full report
-report = calibration_report(predictions, variances, targets)
+# Full evaluation (returns dict with ece, nll, sharpness, mse, rmse)
+metrics = evaluate_calibration(predictions, variances, targets)
+
+# Post-hoc calibration
+scaler = TemperatureScaling()
+scaler.fit(val_mean, val_var, val_targets)
+calibrated_var = scaler.transform(test_var)

 # Visualize
-fig = plot_calibration_curve(predictions, variances, targets)
+plot_reliability_diagram(predictions, variances, targets)
+create_calibration_report(predictions, variances, targets)  # Multi-plot report
 ```
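For intuition, calibration error for a regression model with predicted variances can be formulated as the gap between empirical and Gaussian interval coverage. A self-contained sketch of one common formulation (the exact binning in molax's `expected_calibration_error` may differ):

```python
import math
import numpy as np

def regression_ece(means: np.ndarray, variances: np.ndarray,
                   targets: np.ndarray, n_bins: int = 10) -> float:
    """Compare empirical coverage of central prediction intervals
    against the Gaussian expectation, averaged over thresholds."""
    z = np.abs(targets - means) / np.sqrt(variances)  # standardized errors
    errs = []
    for t in np.linspace(0.1, 3.0, n_bins):
        expected = math.erf(t / math.sqrt(2.0))   # P(|Z| <= t), Z ~ N(0, 1)
        observed = float(np.mean(z <= t))         # empirical coverage
        errs.append(abs(observed - expected))
    return float(np.mean(errs))
```

A perfectly calibrated model yields near-zero error; overconfident variances inflate the standardized errors and push the score up.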
### Optimizer Pattern (Flax 0.11+)
@@ -241,7 +169,11 @@ pytest tests/test_gat.py -v # GAT model tests
 pytest tests/test_graph_transformer.py -v  # Graph Transformer tests
 pytest tests/test_ensemble.py -v           # Ensemble tests
 pytest tests/test_evidential.py -v         # Evidential tests
-pytest tests/test_acquisition.py -v        # Acquisition tests
+pytest tests/test_acquisition.py -v        # Acquisition tests (uncertainty/diversity)
+pytest tests/test_bald.py -v               # BALD acquisition tests
+pytest tests/test_batch_aware.py -v        # BatchBALD/DPP tests
+pytest tests/test_coreset.py -v            # CoreSet acquisition tests
+pytest tests/test_expected_change.py -v    # Expected gradient length tests
 pytest tests/test_calibration.py -v        # Calibration tests
 ```
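`test_bald.py` above exercises mutual-information acquisition. For Gaussian ensembles, one common moment-matched approximation of that score can be sketched as follows (illustrative only, not molax's `bald_sampling`):

```python
import numpy as np

def gaussian_bald(member_means: np.ndarray, member_vars: np.ndarray) -> np.ndarray:
    """Approximate mutual information I[y; theta] for an ensemble of
    Gaussian predictors, shapes [n_members, n_samples].

    The Gaussian entropy is 0.5*log(2*pi*e*var); the predictive mixture
    is approximated by a single moment-matched Gaussian."""
    total_var = member_means.var(axis=0) + member_vars.mean(axis=0)
    h_predictive = 0.5 * np.log(2 * np.pi * np.e * total_var)
    h_members = (0.5 * np.log(2 * np.pi * np.e * member_vars)).mean(axis=0)
    return h_predictive - h_members  # zero when members agree exactly
```

The score is zero when all members make identical predictions and grows with member disagreement, which is the behavior BALD-style acquisition selects for.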

@@ -257,21 +189,21 @@ Download: `python scripts/download_esol.py`

 ## GitHub CLI

-Use `gh_cli` command to interact with GitHub:
+Use `gh` command to interact with GitHub:

 ```bash
 # Check workflow runs
-gh_cli run list
+gh run list

 # View specific run
-gh_cli run view <run-id>
+gh run view <run-id>

 # Watch a run in progress
-gh_cli run watch <run-id>
+gh run watch <run-id>

 # View workflow logs
-gh_cli run view <run-id> --log
+gh run view <run-id> --log

 # Trigger workflow manually
-gh_cli workflow run ci.yml
+gh workflow run ci.yml
 ```

0 commit comments