@@ -29,8 +29,11 @@ python examples/active_learning_benchmark.py
2929python examples/mpnn_demo.py
3030python examples/gat_demo.py
3131python examples/graph_transformer_demo.py
32- python examples/ensemble_demo.py
33- python examples/evidential_demo.py
32+ python examples/ensemble_active_learning.py
33+ python examples/evidential_active_learning.py
34+ python examples/acquisition_strategies_demo.py
35+ python examples/calibration_comparison.py
36+ python examples/uncertainty_gcn_demo.py
3437```
3538
3639## Architecture
@@ -83,141 +86,66 @@ jraph.GraphsTuple (batched - all molecules as one graph)
8386| ` molax/models/ensemble.py ` | ` EnsembleConfig ` , ` DeepEnsemble ` for ensemble uncertainty |
8487| ` molax/models/evidential.py ` | ` EvidentialConfig ` , ` EvidentialGCN ` for evidential uncertainty |
8588| ` molax/utils/data.py ` | ` MolecularDataset ` , ` smiles_to_jraph ` , ` batch_graphs ` |
86- | ` molax/acquisition/uncertainty.py ` | ` uncertainty_sampling ` , ` ensemble_uncertainty_sampling ` , ` evidential_uncertainty_sampling ` |
87- | ` molax/acquisition/diversity.py ` | ` diversity_sampling ` |
88- | ` molax/metrics/calibration.py ` | ` expected_calibration_error ` , ` calibration_report ` |
89- | ` molax/metrics/visualization.py ` | ` plot_calibration_curve ` , ` plot_reliability_diagram ` |
89+ | ` molax/acquisition/uncertainty.py ` | ` uncertainty_sampling ` , ` ensemble_uncertainty_sampling ` , ` evidential_uncertainty_sampling ` , ` diversity_sampling ` , ` combined_* ` strategies |
90+ | ` molax/acquisition/bald.py ` | ` bald_sampling ` , ` ensemble_bald_sampling ` , ` evidential_bald_sampling ` (mutual information) |
91+ | ` molax/acquisition/batch_aware.py ` | ` batch_bald_sampling ` , ` dpp_sampling ` , ` combined_batch_acquisition ` |
92+ | ` molax/acquisition/coreset.py ` | ` coreset_sampling ` , ` coreset_sampling_with_scores ` (K-center greedy) |
93+ | ` molax/acquisition/expected_change.py ` | ` egl_sampling ` , ` egl_acquisition ` (expected gradient length) |
94+ | ` molax/metrics/calibration.py ` | ` expected_calibration_error ` , ` evaluate_calibration ` , ` TemperatureScaling ` , ` negative_log_likelihood ` , ` sharpness ` |
95+ | ` molax/metrics/visualization.py ` | ` plot_reliability_diagram ` , ` plot_calibration_comparison ` , ` plot_uncertainty_vs_error ` , ` create_calibration_report ` |
9096
9197### Model API
9298
93- ``` python
94- from molax.models.gcn import GCNConfig, UncertaintyGCN
95-
96- config = GCNConfig(
97- node_features = 6 , # Atom features
98- hidden_features = [64 , 64 ], # GCN layers
99- out_features = 1 , # Output dim
100- dropout_rate = 0.1 ,
101- )
102- model = UncertaintyGCN(config, rngs = nnx.Rngs(0 ))
99+ All models share a common pattern: ` Config ` → ` Model(config, rngs) ` → ` model(graphs, training) ` → ` (mean, variance) ` .
103100
104- # Forward pass
105- mean, variance = model(batched_graphs, training = True )
106- ```
101+ | Model | Config | Extra Config Params | Output |
102+ | -------| --------| -------------------| --------|
103+ | ` UncertaintyGCN ` | ` GCNConfig ` | — | ` (mean, var) ` |
104+ | ` UncertaintyMPNN ` | ` MPNNConfig ` | ` edge_features ` , ` aggregation ` | ` (mean, var) ` |
105+ | ` UncertaintyGAT ` | ` GATConfig ` | ` n_heads ` , ` attention_dropout_rate ` , ` negative_slope ` | ` (mean, var) ` |
106+ | ` UncertaintyGraphTransformer ` | ` GraphTransformerConfig ` | ` n_heads ` , ` ffn_ratio ` , ` pe_type ` , ` pe_dim ` | ` (mean, var) ` |
107+ | ` DeepEnsemble ` | ` EnsembleConfig ` | ` n_members ` | ` (mean, epistemic_var, aleatoric_var) ` |
108+ | ` EvidentialGCN ` | ` EvidentialConfig ` | — | ` (mean, aleatoric_var, epistemic_var) ` |
107109
108- ### Ensemble API
110+ All configs share: ` node_features ` , ` hidden_features ` , ` out_features ` , ` dropout_rate ` . Edge-aware models (MPNN, GAT, GraphTransformer) also take ` edge_features ` .
109111
110112``` python
111- from molax.models.ensemble import EnsembleConfig, DeepEnsemble
113+ # Example: any model follows this pattern
114+ from molax.models import GCNConfig, UncertaintyGCN
112115
113- config = EnsembleConfig(
114- node_features = 6 ,
115- hidden_features = [64 , 64 ],
116- out_features = 1 ,
117- n_members = 5 ,
118- )
119- ensemble = DeepEnsemble(config, rngs = nnx.Rngs(0 ))
116+ config = GCNConfig(node_features = 6 , hidden_features = [64 , 64 ], out_features = 1 , dropout_rate = 0.1 )
117+ model = UncertaintyGCN(config, rngs = nnx.Rngs(0 ))
118+ mean, variance = model(batched_graphs, training = True )
120119
121- # Returns separate epistemic and aleatoric uncertainty
120+ # Ensemble returns 3 values
121+ from molax.models import EnsembleConfig, DeepEnsemble
122+ ensemble = DeepEnsemble(EnsembleConfig(node_features = 6 , hidden_features = [64 , 64 ], out_features = 1 , n_members = 5 ), rngs = nnx.Rngs(0 ))
122123mean, epistemic_var, aleatoric_var = ensemble(batched_graphs, training = False )
123- ```
124-
125- ### Evidential API
126-
127- ``` python
128- from molax.models.evidential import EvidentialConfig, EvidentialGCN
129-
130- config = EvidentialConfig(
131- node_features = 6 ,
132- hidden_features = [64 , 64 ],
133- out_features = 1 ,
134- )
135- model = EvidentialGCN(config, rngs = nnx.Rngs(0 ))
136-
137- # Single forward pass for both uncertainties
138- mean, aleatoric_var, epistemic_var = model(batched_graphs, training = False )
139- ```
140-
141- ### MPNN API
142-
143- ``` python
144- from molax.models.mpnn import MPNNConfig, UncertaintyMPNN
145-
146- config = MPNNConfig(
147- node_features = 6 ,
148- edge_features = 1 , # Bond type feature
149- hidden_features = [64 , 64 ],
150- out_features = 1 ,
151- aggregation = " sum" , # or "mean", "max"
152- dropout_rate = 0.1 ,
153- )
154- model = UncertaintyMPNN(config, rngs = nnx.Rngs(0 ))
155-
156- # Same API as UncertaintyGCN - uses edge features in message passing
157- mean, variance = model(batched_graphs, training = False )
158- ```
159-
160- ### GAT API
161124
162- ``` python
163- from molax.models.gat import GATConfig, UncertaintyGAT
164-
165- config = GATConfig(
166- node_features = 6 ,
167- edge_features = 1 , # Optional: include edge features in attention
168- hidden_features = [64 , 64 ],
169- out_features = 1 ,
170- n_heads = 4 , # Multi-head attention
171- dropout_rate = 0.1 ,
172- attention_dropout_rate = 0.1 ,
173- negative_slope = 0.2 , # LeakyReLU slope
174- )
175- model = UncertaintyGAT(config, rngs = nnx.Rngs(0 ))
176-
177- # Same API as UncertaintyGCN/UncertaintyMPNN - uses attention for aggregation
178- mean, variance = model(batched_graphs, training = False )
179- ```
180-
181- ### Graph Transformer API
182-
183- ``` python
184- from molax.models.graph_transformer import GraphTransformerConfig, UncertaintyGraphTransformer
185-
186- config = GraphTransformerConfig(
187- node_features = 6 ,
188- edge_features = 1 , # Optional: edge features as attention bias
189- hidden_features = [64 , 64 ],
190- out_features = 1 ,
191- n_heads = 4 , # Multi-head self-attention
192- ffn_ratio = 4.0 , # FFN hidden dim = 4 * model dim
193- dropout_rate = 0.1 ,
194- attention_dropout_rate = 0.1 ,
195- pe_type = " rwpe" , # Positional encoding: "rwpe", "laplacian", or "none"
196- pe_dim = 16 , # Positional encoding dimension
197- )
198- model = UncertaintyGraphTransformer(config, rngs = nnx.Rngs(0 ))
199-
200- # Same API as UncertaintyGCN/UncertaintyMPNN/UncertaintyGAT
201- mean, variance = model(batched_graphs, training = False )
202-
203- # Extract embeddings for Core-Set selection
125+ # GraphTransformer supports embedding extraction for CoreSet
204126embeddings = model.extract_embeddings(batched_graphs)
205127```
206128
207129### Calibration Metrics
208130
209131``` python
210- from molax.metrics import expected_calibration_error, calibration_report
211- from molax.metrics.visualization import plot_calibration_curve
132+ from molax.metrics import expected_calibration_error, evaluate_calibration, TemperatureScaling
133+ from molax.metrics import plot_reliability_diagram, create_calibration_report
212134
213135# Compute ECE
214136ece = expected_calibration_error(predictions, variances, targets)
215137
216- # Generate full report
217- report = calibration_report(predictions, variances, targets)
138+ # Full evaluation (returns dict with ece, nll, sharpness, mse, rmse)
139+ metrics = evaluate_calibration(predictions, variances, targets)
140+
141+ # Post-hoc calibration
142+ scaler = TemperatureScaling()
143+ scaler.fit(val_mean, val_var, val_targets)
144+ calibrated_var = scaler.transform(test_var)
218145
219146# Visualize
220- fig = plot_calibration_curve(predictions, variances, targets)
147+ plot_reliability_diagram(predictions, variances, targets)
148+ create_calibration_report(predictions, variances, targets) # Multi-plot report
221149```
222150
223151### Optimizer Pattern (Flax 0.11+)
@@ -241,7 +169,11 @@ pytest tests/test_gat.py -v # GAT model tests
241169pytest tests/test_graph_transformer.py -v # Graph Transformer tests
242170pytest tests/test_ensemble.py -v # Ensemble tests
243171pytest tests/test_evidential.py -v # Evidential tests
244- pytest tests/test_acquisition.py -v # Acquisition tests
172+ pytest tests/test_acquisition.py -v # Acquisition tests (uncertainty/diversity)
173+ pytest tests/test_bald.py -v # BALD acquisition tests
174+ pytest tests/test_batch_aware.py -v # BatchBALD/DPP tests
175+ pytest tests/test_coreset.py -v # CoreSet acquisition tests
176+ pytest tests/test_expected_change.py -v # Expected gradient length tests
245177pytest tests/test_calibration.py -v # Calibration tests
246178```
247179
@@ -257,21 +189,21 @@ Download: `python scripts/download_esol.py`
257189
258190## GitHub CLI
259191
260- Use ` gh_cli ` command to interact with GitHub:
192+ Use ` gh ` command to interact with GitHub:
261193
262194``` bash
263195# Check workflow runs
264- gh_cli run list
196+ gh run list
265197
266198# View specific run
267- gh_cli run view < run-id>
199+ gh run view < run-id>
268200
269201# Watch a run in progress
270- gh_cli run watch < run-id>
202+ gh run watch < run-id>
271203
272204# View workflow logs
273- gh_cli run view < run-id> --log
205+ gh run view < run-id> --log
274206
275207# Trigger workflow manually
276- gh_cli workflow run ci.yml
208+ gh workflow run ci.yml
277209```
0 commit comments