
Commit 6325dab

release: create release-0.1.2 branch
1 parent f5f2559 commit 6325dab

16 files changed: +275 -153 lines changed


.github/workflows/deploy_docs.yaml

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ jobs:
         pip install -U pip setuptools
         pip install poetry==${POETRY_VERSION}
         poetry install
+        poetry run pip install git+https://github.com/jax-md/jax-md.git
 
     - name: Sphinx build
       run: |

CHANGELOG

Lines changed: 13 additions & 0 deletions

@@ -1,5 +1,18 @@
 # Changelog
 
+## Release 0.1.2
+
+- Fixing the computation of metrics during training, by reweighting the metrics of
+each batch to account for a varying number of real graphs per batch; this results
+in the metrics being independent of the batching strategy and number of GPUs employed
+- In addition to the point above, fixing the computation of RMSE metrics by now
+only computing MSE metrics in the loss and taking the square root at the very end
+when logging
+- Deleting relative and 95-percentile metrics, as they are not straightforward to
+compute on-the-fly with our dynamic batching strategy; we recommend to compute them
+separately for a model checkpoint if necessary
+- Small amount of modifications to README and documentation
+
 ## Release 0.1.1
 
 - Small amount of modifications to README and documentation
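The first two changelog points can be sketched in a few lines. This is not code from the library: `combine_batch_metrics` and the numbers are hypothetical, and the only assumption is that each batch reports an MSE value together with its count of real (non-dummy) graphs.

```python
import jax.numpy as jnp


def combine_batch_metrics(batch_mse: jnp.ndarray, real_graphs: jnp.ndarray) -> jnp.ndarray:
    """Average per-batch MSE values, weighted by the number of real (non-dummy) graphs."""
    weights = real_graphs / real_graphs.sum()
    return jnp.sum(weights * batch_mse)


# Hypothetical per-batch energy MSEs; the last batch is padded with dummy graphs.
batch_mse_e = jnp.array([0.040, 0.090, 0.010])
real_graphs = jnp.array([32.0, 32.0, 7.0])

dataset_mse_e = combine_batch_metrics(batch_mse_e, real_graphs)
dataset_rmse_e = jnp.sqrt(dataset_mse_e)  # square root taken only at the very end, when logging
```

Assuming each per-batch MSE is the mean over that batch's real graphs, this weighted average equals the plain average over all real graphs, so the result no longer depends on how the graphs were split into batches or across GPUs.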

README.md

Lines changed: 2 additions & 2 deletions

@@ -56,7 +56,7 @@ more information.
 At time of release, the following install command is supported:
 
 ```bash
-pip install -U "jax[cuda12]"
+pip install -U "jax[cuda12]==0.4.33"
 ```
 
 Note that using the TPU version of *jaxlib* is, in principle, also supported by

@@ -169,7 +169,7 @@ Scott Cameron, Louis Robinson, Tom Barrett, and Alex Laterre.
 
 ## 📚 Citing our work
 
-We kindly request to cite [our white paper](https://arxiv.org/abs/2505.22397)
+We kindly request that you cite [our white paper](https://arxiv.org/abs/2505.22397)
 when using this library:
 
 C. Brunken, O. Peltre, H. Chomet, L. Walewski, M. McAuliffe, V. Heyraud,

docs/source/api_reference/training/training_io_handling.rst

Lines changed: 2 additions & 0 deletions

@@ -30,3 +30,5 @@ IO handling during training
 .. autofunction:: log_metrics_to_table
 
 .. autofunction:: log_metrics_to_line
+
+.. autofunction:: convert_mse_to_rmse_in_logs

docs/source/installation/index.rst

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ At time of release, the following install command is supported:
 
 .. code-block:: bash
 
-   pip install -U "jax[cuda12]"
+   pip install -U "jax[cuda12]==0.4.33"
 
 Note that using the TPU version of *jaxlib* is, in principle, also supported by
 this library. However, it has not been thoroughly tested and should therefore be

docs/source/user_guide/training.rst

Lines changed: 27 additions & 6 deletions

@@ -82,11 +82,11 @@ Loss
 
 All losses must be implemented as derived classes of
 :py:class:`Loss <mlip.models.loss.Loss>`. We currently implement two losses, the
-Mean-Squared-Error loss (:py:class:`Loss <mlip.models.loss.MSELoss>`), and the
-Huber loss (:py:class:`Loss <mlip.models.loss.HuberLoss>`), which are both losses
+Mean-Squared-Error loss (:py:class:`MSELoss <mlip.models.loss.MSELoss>`), and the
+Huber loss (:py:class:`HuberLoss <mlip.models.loss.HuberLoss>`), which are both losses
 that are derived from a loss that computes errors for energies, forces, and stress,
 and weights them according to some weighting schedule that can depend on the epoch
-number (base class: :py:class:`Loss <mlip.models.loss.WeightedEFSLoss>`).
+number (base class: :py:class:`WeightedEFSLoss <mlip.models.loss.WeightedEFSLoss>`).
 
 If one wants to use the MSE loss for training, simply run this code to initialize it:
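The code block referenced by that last context line is not part of this hunk. As a purely hypothetical sketch of such an initialization (assuming `MSELoss` can be constructed with only the optional `extended_metrics` flag described in the next hunk):

```python
# Hypothetical sketch; the actual snippet lives in the documentation outside this hunk.
from mlip.models.loss import MSELoss

# `extended_metrics` is described below as an optional constructor argument (default: False);
# that no other arguments are needed here is an assumption.
loss = MSELoss(extended_metrics=True)
```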

@@ -105,7 +105,16 @@ If one wants to use the MSE loss for training, simply run this code to initializ
 
 For our two implemented losses, we also allow for computation of more extended metrics
 by setting the `extended_metrics` argument to `True` in the loss constructor.
-By default, it is `False`.
+By default, it is `False`. See the documentation of
+the :py:class:`call method <mlip.models.loss.WeightedEFSLoss.__call__>` of the class
+:py:class:`WeightedEFSLoss <mlip.models.loss.WeightedEFSLoss>` for more information on
+the returned metrics.
+
+Furthermore, note that even though the loss class is supposed to provide these metrics
+averaged just over a given input batch, we reweight these metrics based on the number
+of real (not dummy) graphs per batch in the training loop, such that the
+resulting metrics that are logged during training are accurately averaged
+over the whole dataset.
 
 .. _training_optimizer:
 

@@ -119,9 +128,13 @@ however, this library also has a specialized pipeline that has been inspired by
 `this <https://github.com/ACEsuit/mace>`_ PyTorch MACE implementation.
 It is configurable via a
 :py:class:`OptimizerConfig <mlip.training.optimizer_config.OptimizerConfig>` object that
-has sensible defaults set for training MLIP models.
+has sensible defaults set for training MLIP models. However, we suggest to also check
+out `our white paper <https://arxiv.org/abs/2505.22397>`_ for recommendations for
+sensible ways to adapt the defaults for specific models, for instance, ViSNet and
+NequIP seem to be more prone to NaNs with the default learning rate and benefit from
+using a smaller one such as ``1e-4``.
 
-This default MLIP optimizer can be set up like this:
+The default MLIP optimizer can be set up like this:
 
 .. code-block:: python
 

@@ -206,6 +219,14 @@ which prints the training metrics to the console in a nice table format (using
 :py:func:`log_metrics_to_line() <mlip.training.training_loggers.log_metrics_to_line>`,
 which logs the metrics in a single line.
 
+These logging functions automatically convert any MSE metrics to RMSE for easier
+interpretation. Internally, we only keep track of MSE instead of RMSE because we must
+ensure that the square root is taken at the very end and not before any averaging
+across batches or devices happens. If one desires to do the same conversion in their
+custom logging function, see
+:py:func:`convert_mse_to_rmse_in_logs() <mlip.training.training_loggers.convert_mse_to_rmse_in_logs>`,
+which is a helper function we provide for this task.
+
 Note that it is possible to omit the `io_handler` argument in the
 :py:class:`TrainingLoop <mlip.training.training_loop.TrainingLoop>` class. In that case,
 a default IO handler is set up internally and used. This IO handler does not include
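To make the MSE-to-RMSE convention concrete, here is a small, hypothetical stand-in for the conversion a custom logger needs. It is not the implementation of `convert_mse_to_rmse_in_logs`; the only assumption is that MSE entries use the `mse_*` key prefix produced by `compute_eval_metrics` (see `mlip/models/loss_helpers.py` below).

```python
import math


def mse_to_rmse(metrics: dict[str, float]) -> dict[str, float]:
    """Return a copy of the metrics dict with every mse_* entry replaced by an rmse_* entry."""
    converted = {}
    for key, value in metrics.items():
        if key.startswith("mse"):
            converted["r" + key] = math.sqrt(value)  # e.g. "mse_e" -> "rmse_e"
        else:
            converted[key] = value
    return converted


print(mse_to_rmse({"mae_e": 0.12, "mse_e": 0.04, "mse_f": 0.09}))
# {'mae_e': 0.12, 'rmse_e': 0.2, 'rmse_f': 0.3}
```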

mlip/models/loss.py

Lines changed: 34 additions & 0 deletions

@@ -98,6 +98,40 @@ def __call__(
         epoch: int,
         eval_metrics: bool = False,
     ) -> tuple[float, dict[str, float]]:
+        """The call function that outputs the loss and metrics (auxiliary data).
+
+        The metrics returned by this class if `eval_metrics=False`:
+        - average loss per structure
+        - energy, forces, and stress weighting factors
+
+        The metrics returned by this class if `eval_metrics=True`:
+        - average loss per structure
+        - MAE and MAE per atom (for energies, forces, and stress)
+        - MSE and MSE per atom (for energies, forces, and stress)
+
+        **Important note 1:** we provide MSE instead of RMSE, because MSE and MAE
+        metrics allow for downstream reweighting by number of real graphs per batch
+        to obtain the correct metrics over the whole dataset. This reweighting
+        is necessary as not every batch has the same number of real
+        (not dummy) graphs and is therefore done as part of the training loop.
+        Feel free to take the square root of the final MSE metric before logging it.
+        The default loggers provided with this library also report RMSE instead of MSE
+        during training.
+
+        **Important note 2:** we use per-component errors for forces instead of
+        computing force error vectors per atom and then computing their norm.
+
+        Args:
+            prediction: The force field predictor's outputs.
+            ref_graph: The reference graph holding the ground truth data.
+            epoch: The epoch number.
+            eval_metrics: Switch deciding whether to include additional
+                evaluation metrics to the returned dictionary.
+                Default is `False`.
+
+        Returns:
+            The loss and the auxiliary metrics dictionary.
+        """
         # Get weights
         energy_weight = self.energy_weight_schedule(epoch)
         forces_weight = self.forces_weight_schedule(epoch)
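Important note 2 in the docstring above can be illustrated numerically. The snippet below is a hypothetical illustration (the error values are made up), not library code; it contrasts the per-component convention used here with the per-atom-norm alternative.

```python
import jax.numpy as jnp

# Made-up force errors for three atoms, shape (n_atoms, 3).
delta_f = jnp.array([[0.1, -0.2, 0.0],
                     [0.0, 0.3, 0.1],
                     [-0.1, 0.0, 0.2]])

# Per-component convention (used here): mean squared error over all 3 * n_atoms components.
mse_per_component = jnp.mean(jnp.square(delta_f))

# Alternative (not used here): squared norm of each atom's error vector, averaged per atom.
mse_per_atom_norm = jnp.mean(jnp.sum(jnp.square(delta_f), axis=-1))

# The two differ by a factor of 3 (the number of components), so reported force
# metrics depend on which convention is chosen.
```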

mlip/models/loss_helpers.py

Lines changed: 25 additions & 115 deletions

@@ -52,52 +52,16 @@ def compute_mae_stress(delta: jnp.ndarray, mask) -> float:
     return _masked_mean_stress(jnp.abs(delta), mask)
 
 
-def compute_rel_mae(delta: jnp.ndarray, target_val: jnp.ndarray, mask) -> float:
-    target_norm = _masked_mean(jnp.abs(target_val), mask)
-    return _masked_mean(jnp.abs(delta), mask) / (target_norm + 1e-30)
+def compute_mse(delta: jnp.ndarray, mask) -> float:
+    return _masked_mean(jnp.square(delta), mask)
 
 
-def compute_rel_mae_f(delta: jnp.ndarray, target_val: jnp.ndarray, mask) -> float:
-    target_norm = _masked_mean_f(jnp.abs(target_val), mask)
-    return _masked_mean_f(jnp.abs(delta), mask) / (target_norm + 1e-30)
+def compute_mse_f(delta: jnp.ndarray, mask) -> float:
+    return _masked_mean_f(jnp.square(delta), mask)
 
 
-def compute_rel_mae_stress(delta: jnp.ndarray, target_val: jnp.ndarray, mask) -> float:
-    target_norm = _masked_mean_stress(jnp.abs(target_val), mask)
-    return _masked_mean_stress(jnp.abs(delta), mask) / (target_norm + 1e-30)
-
-
-def compute_rmse(delta: jnp.ndarray, mask) -> float:
-    return jnp.sqrt(_masked_mean(jnp.square(delta), mask))
-
-
-def compute_rmse_f(delta: jnp.ndarray, mask) -> float:
-    return jnp.sqrt(_masked_mean_f(jnp.square(delta), mask))
-
-
-def compute_rmse_stress(delta: jnp.ndarray, mask) -> float:
-    return jnp.sqrt(_masked_mean_stress(jnp.square(delta), mask))
-
-
-def compute_rel_rmse(delta: jnp.ndarray, target_val: jnp.ndarray, mask) -> float:
-    target_norm = jnp.sqrt(_masked_mean(jnp.square(target_val), mask))
-    return jnp.sqrt(_masked_mean(jnp.square(delta), mask)) / (target_norm + 1e-30)
-
-
-def compute_rel_rmse_f(delta: jnp.ndarray, target_val: jnp.ndarray, mask) -> float:
-    target_norm = jnp.sqrt(_masked_mean_f(jnp.square(target_val), mask))
-    return jnp.sqrt(_masked_mean_f(jnp.square(delta), mask)) / (target_norm + 1e-30)
-
-
-def compute_rel_rmse_stress(delta: jnp.ndarray, target_val: jnp.ndarray, mask) -> float:
-    target_norm = jnp.sqrt(_masked_mean_stress(jnp.square(target_val), mask))
-    return jnp.sqrt(_masked_mean_stress(jnp.square(delta), mask)) / (
-        target_norm + 1e-30
-    )
-
-
-def compute_q95(delta: jnp.ndarray) -> float:
-    return jnp.percentile(jnp.abs(delta), q=95)
+def compute_mse_stress(delta: jnp.ndarray, mask) -> float:
+    return _masked_mean_stress(jnp.square(delta), mask)
 
 
 def _sum_nodes_of_the_same_graph(

@@ -295,120 +259,66 @@ def compute_eval_metrics(
             stress_per_atom_list.append(ref_graph.globals.stress / jnp.sum(node_mask))
 
     metrics = {
-        "mae_e": None,
-        "rel_mae_e": None,
-        "mae_e_per_atom": None,
-        "rel_mae_e_per_atom": None,
-        "rmse_e": None,
-        "rel_rmse_e": None,
-        "rmse_e_per_atom": None,
-        "rel_rmse_e_per_atom": None,
-        "q95_e": None,
-        "mae_f": None,
-        "rel_mae_f": None,
-        "rmse_f": None,
-        "rel_rmse_f": None,
-        "q95_f": None,
-        "mae_stress": None,
-        "rel_mae_stress": None,
-        "mae_stress_per_atom": None,
-        "rel_mae_stress_per_atom": None,
-        "rmse_stress": None,
-        "rel_rmse_stress": None,
-        "rmse_stress_per_atom": None,
-        "rel_rmse_stress_per_atom": None,
-        "q95_stress": None,
+        "mae_e": jnp.nan,
+        "mae_e_per_atom": jnp.nan,
+        "mse_e": jnp.nan,
+        "mse_e_per_atom": jnp.nan,
+        "mae_f": jnp.nan,
+        "mse_f": jnp.nan,
+        "mae_stress": jnp.nan,
+        "mae_stress_per_atom": jnp.nan,
+        "mse_stress": jnp.nan,
+        "mse_stress_per_atom": jnp.nan,
     }
 
     if len(delta_es_list) > 0:
         delta_es = jnp.concatenate(delta_es_list, axis=0)
         delta_es_per_atom = jnp.concatenate(delta_es_per_atom_list, axis=0)
-        es = jnp.concatenate(es_list, axis=0)
-        es_per_atom = jnp.concatenate(es_per_atom_list, axis=0)
 
         metrics.update(
            {
                # Mean absolute error
                "mae_e": compute_mae(delta_es, graph_mask),
-               # Root-mean-square error
-               "rmse_e": compute_rmse(delta_es, graph_mask),
+               # Mean-square error
+               "mse_e": compute_mse(delta_es, graph_mask),
            }
        )
        if extended_metrics:
            metrics.update(
                {
                    # Mean absolute error
-                   "rel_mae_e": compute_rel_mae(delta_es, es, graph_mask),
                    "mae_e_per_atom": compute_mae(delta_es_per_atom, graph_mask),
-                   "rel_mae_e_per_atom": compute_rel_mae(
-                       delta_es_per_atom, es_per_atom, graph_mask
-                   ),
-                   # Root-mean-square error
-                   "rel_rmse_e": compute_rel_rmse(delta_es, es, graph_mask),
-                   "rmse_e_per_atom": compute_rmse(delta_es_per_atom, graph_mask),
-                   "rel_rmse_e_per_atom": compute_rel_rmse(
-                       delta_es_per_atom, es_per_atom, graph_mask
-                   ),
-                   # Q_95
-                   "q95_e": compute_q95(delta_es),
+                   # Mean-square error
+                   "mse_e_per_atom": compute_mse(delta_es_per_atom, graph_mask),
                }
            )
 
    if len(delta_fs_list) > 0:
        delta_fs = jnp.concatenate(delta_fs_list, axis=0)
-       fs = jnp.concatenate(fs_list, axis=0)
-
        metrics.update(
            {
                # Mean absolute error
                "mae_f": compute_mae_f(delta_fs, node_mask),
-               # Root-mean-square error
-               "rmse_f": compute_rmse_f(delta_fs, node_mask),
+               # Mean-square error
+               "mse_f": compute_mse_f(delta_fs, node_mask),
            }
        )
-       if extended_metrics:
-           metrics.update(
-               {
-                   # Mean absolute error
-                   "rel_mae_f": compute_rel_mae_f(delta_fs, fs, node_mask),
-                   # Root-mean-square error
-                   "rel_rmse_f": compute_rel_rmse_f(delta_fs, fs, node_mask),
-                   # Q_95
-                   "q95_f": compute_q95(delta_fs),
-               }
-           )
 
    if len(delta_stress_list) > 0 and extended_metrics:
        delta_stress = jnp.concatenate(delta_stress_list, axis=0)
        delta_stress_per_atom = jnp.concatenate(delta_stress_per_atom_list, axis=0)
-       stress = jnp.concatenate(stress_list, axis=0)
-       stress_per_atom = jnp.concatenate(stress_per_atom_list, axis=0)
        metrics.update(
            {
                # Mean absolute error
                "mae_stress": compute_mae_stress(delta_stress, graph_mask),
-               "rel_mae_stress": compute_rel_mae_stress(
-                   delta_stress, stress, graph_mask
-               ),
                "mae_stress_per_atom": compute_mae_stress(
                    delta_stress_per_atom, graph_mask
                ),
-               "rel_mae_stress_per_atom": compute_rel_mae_stress(
-                   delta_stress_per_atom, stress_per_atom, graph_mask
-               ),
-               # Root-mean-square error
-               "rmse_stress": compute_rmse_stress(delta_stress, graph_mask),
-               "rel_rmse_stress": compute_rel_rmse_stress(
-                   delta_stress, stress, graph_mask
-               ),
-               "rmse_stress_per_atom": compute_rmse_stress(
+               # Mean-square error
+               "mse_stress": compute_mse_stress(delta_stress, graph_mask),
+               "mse_stress_per_atom": compute_mse_stress(
                    delta_stress_per_atom, graph_mask
                ),
-               "rel_rmse_stress_per_atom": compute_rel_rmse_stress(
-                   delta_stress_per_atom, stress_per_atom, graph_mask
-               ),
-               # Q_95
-               "q95_stress": compute_q95(delta_stress),
            }
        )
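The new `compute_mse*` helpers above delegate to `_masked_mean`, `_masked_mean_f`, and `_masked_mean_stress`, which are not part of this diff. Below is a hypothetical sketch of such a masked mean, only to show how dummy entries are excluded from the MSE; the actual helpers in `loss_helpers.py` may differ.

```python
import jax.numpy as jnp


def masked_mean_sketch(values: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    """Mean of `values` over entries where `mask` is 1, ignoring padded (dummy) entries."""
    mask = mask.astype(values.dtype)
    return jnp.sum(values * mask) / jnp.maximum(jnp.sum(mask), 1.0)


def compute_mse_sketch(delta: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    # Mirrors the structure of compute_mse above: a masked mean of squared errors.
    return masked_mean_sketch(jnp.square(delta), mask)


# Two real graphs and one dummy graph whose (arbitrary) error must not count.
delta_e = jnp.array([0.2, -0.1, 5.0])
graph_mask = jnp.array([1.0, 1.0, 0.0])
print(compute_mse_sketch(delta_e, graph_mask))  # ≈ 0.025
```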

mlip/training/__init__.py

Lines changed: 5 additions & 1 deletion

@@ -18,6 +18,10 @@
 )
 from mlip.training.optimizer_config import OptimizerConfig
 from mlip.training.training_io_handler import TrainingIOHandler, TrainingIOHandlerConfig
-from mlip.training.training_loggers import log_metrics_to_line, log_metrics_to_table
+from mlip.training.training_loggers import (
+    convert_mse_to_rmse_in_logs,
+    log_metrics_to_line,
+    log_metrics_to_table,
+)
 from mlip.training.training_loop import TrainingLoop
 from mlip.training.training_loop_config import TrainingLoopConfig
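With this re-export, the new helper is importable directly from the package namespace alongside the existing loggers:

```python
from mlip.training import (
    convert_mse_to_rmse_in_logs,
    log_metrics_to_line,
    log_metrics_to_table,
)
```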
