
Commit 17f9618

release: create release-0.1.0 branch
1 parent 2837288 commit 17f9618

File tree

7 files changed (+42, -25 lines)


CHANGELOG

Lines changed: 7 additions & 1 deletion
@@ -2,4 +2,10 @@
 
 ## Release 0.1.0
 
-To be added upon release...
+- Implemented model architectures: MACE, NequIP and ViSNet
+- Dataset preprocessing
+- Training of MLIP models
+- Batched inference with trained MLIP models
+- MD simulations with MLIP models using JAX-MD and ASE simulation backends
+- Energy minimizations with MLIP models using the same simulation backends
+- Fine-tuning of pre-trained MLIP models (only for MACE)

README.md

Lines changed: 24 additions & 17 deletions
@@ -2,11 +2,6 @@
 
 ![badge](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/mlipbot/b6e4bf384215e60775699a83c3c00aef/raw/pytest-coverage-comment.json)
 
-## ⚠️ Important note
-
-The *mlip* library is currently available as a pre-release version only.
-The release of the first stable version will follow later this month.
-
 ## 👀 Overview
 
 *mlip* is a Python library for training and deploying
@@ -27,6 +22,9 @@ material science applications, (2) **extensibility and flexibility** for users m
 experienced with MLIP and JAX, and (3) a focus on **high inference speeds** that enable
 running long MD simulations on large systems which we believe is necessary in order to
 bring MLIP to large-scale industrial application.
+See our [inference speed benchmark](#-inference-time-benchmarks) below.
+With our library, we observe a 10x speedup on 138 atoms and up to a 4x speedup
+on 1205 atoms over equivalent implementations relying on Torch and ASE.
 
 See the [Installation](#-installation) section for details on how to install
 MLIP-JAX and the example Google Colab notebooks linked below for a quick way
@@ -75,6 +73,11 @@ directly from the GitHub repository, like this:
 pip install git+https://github.com/jax-md/jax-md.git
 ```
 
+Furthermore, note that among our library dependencies we have pinned the versions
+of *jaxlib*, *matscipy*, and *orbax-checkpoint* to one specific version each to
+prioritize reliability; however, we plan to allow for a more flexible definition of
+our dependencies in upcoming releases.
+
 ## ⚡ Examples
 
 In addition to the in-depth tutorials provided as part of our documentation
@@ -130,35 +133,39 @@ please refer to the model cards of the relevant HuggingFace repos.
 
 ## 🚀 Inference time benchmarks
 
-In order to showcase the runtime efficiency, we conducted benchmarks across all three models
-on two different systems: 1UAO (138 atoms) and 1ABT (1205 atoms), both run for 1ns on a H100
-NVidia GPU. All model implementations are our own, including the Torch + ASE benchmarks, and
+In order to showcase the runtime efficiency, we conducted benchmarks across all three
+models on two different systems: Chignolin
+([1UAO](https://www.rcsb.org/structure/1UAO), 138 atoms) and Alpha-bungarotoxin
+([1ABT](https://www.rcsb.org/structure/1ABT), 1205 atoms), both run for 1 ns of
+MD simulation on an H100 NVIDIA GPU.
+All model implementations are our own, including the Torch + ASE benchmarks, and
 should not be considered representative of the performance of the code developed by the
-original authors of the methods. Further details can be found in our whitepaper (see below).
+original authors of the methods.
+Further details can be found in our white paper (see [below](#-citing-our-work)).
 
 **MACE (2,139,152 parameters):**
-| Systems | JAX + JAX MD | JAX + ASE | Torch + ASE |
+| Systems | JAX + JAX-MD | JAX + ASE | Torch + ASE |
 | --------- |-------------:|-------------:|-------------:|
-| 1UAO | 6.3 ms/step | 11.6 ms/step | TBC ms/step |
-| 1ABT | TBC ms/step | TBC ms/step | TBC ms/step |
+| 1UAO | 6.3 ms/step | 11.6 ms/step | 44.2 ms/step |
+| 1ABT | 66.8 ms/step | 99.5 ms/step | 157.2 ms/step |
 
 **ViSNet (1,137,922 parameters):**
-| Systems | JAX + JAX MD | JAX + ASE | Torch + ASE |
+| Systems | JAX + JAX-MD | JAX + ASE | Torch + ASE |
 | --------- |-------------:|-------------:|-------------:|
 | 1UAO | 2.9 ms/step | 6.2 ms/step | 33.8 ms/step |
-| 1ABT | 25.4 ms/step | TBC ms/step | TBC ms/step |
+| 1ABT | 25.4 ms/step | 46.4 ms/step | 101.6 ms/step |
 
 **NequIP (1,327,792 parameters):**
-| Systems | JAX + JAX MD | JAX + ASE | Torch + ASE |
+| Systems | JAX + JAX-MD | JAX + ASE | Torch + ASE |
 | --------- |-------------:|-------------:|-------------:|
 | 1UAO | 3.8 ms/step | 8.5 ms/step | 38.7 ms/step |
-| 1ABT | TBC ms/step | TBC ms/step | TBC ms/step |
+| 1ABT | 67.0 ms/step | 105.7 ms/step | 117.0 ms/step |
 
 ## 🙏 Acknowledgments
 
 We would like to acknowledge beta testers for this library: Isabel Wilkinson,
 Nick Venanzi, Hassan Sirelkhatim, Leon Wehrhan, Sebastien Boyer, Massimo Bortone,
-Tom Barrett, and Alex Laterre.
+Scott Cameron, Louis Robinson, Tom Barrett, and Alex Laterre.
 
 ## 📚 Citing our work
 
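For readers who want to reproduce this kind of per-step timing with the ASE backend, the sketch below shows one way to measure it. It is a rough illustration only: `make_mlip_calculator` is a hypothetical placeholder (the mlip calculator API is not part of this commit), and the MD settings are generic ASE choices rather than the benchmark configuration behind the tables above.

```python
# Illustrative sketch of measuring ms/step for an MLIP-driven MD run with ASE.
# `make_mlip_calculator` is a hypothetical placeholder for wrapping a trained
# mlip force field as an ASE calculator; settings are generic, not the benchmark setup.
import time

from ase import units
from ase.io import read
from ase.md.langevin import Langevin

atoms = read("1uao.pdb")                   # e.g. Chignolin (1UAO, 138 atoms)
atoms.calc = make_mlip_calculator("mace")  # hypothetical calculator factory

dyn = Langevin(atoms, timestep=1.0 * units.fs, temperature_K=300.0, friction=0.01)
dyn.run(100)  # warm-up so JIT compilation and neighbor-list setup are excluded

n_steps = 1000
start = time.perf_counter()
dyn.run(n_steps)
elapsed = time.perf_counter() - start
print(f"{1000.0 * elapsed / n_steps:.2f} ms/step")
```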

mlip/inference/batched_inference.py

Lines changed: 4 additions & 0 deletions
@@ -136,6 +136,10 @@ def run_batched_inference(
 stress tensors. Result will be returned as a list of `Prediction` objects, one
 for each input structure.
 
+Note: When using ``batch_size=1``, we recommend setting ``max_n_node`` and
+``max_n_edge`` explicitly to avoid edge cases in the automated computation of these
+parameters that may cause errors.
+
 Args:
 structures: The structures to batch and then compute predictions for.
 force_field: The force field object to compute the predictions with.
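A minimal usage sketch of the recommendation added above: the function name and the `structures`, `force_field`, `batch_size`, `max_n_node`, and `max_n_edge` arguments come from this diff, while the import path mirrors the file location and the way the structures and force field are obtained is assumed for illustration.

```python
# Hypothetical sketch of the batch_size=1 recommendation from the docstring above.
# Argument names come from the diff; how structures and the force field are obtained
# is assumed for illustration and may differ from the actual mlip API.
from ase.io import read

from mlip.inference.batched_inference import run_batched_inference

structures = read("structures.xyz", index=":")  # input structures (assumed ASE Atoms list)
force_field = ...  # placeholder: a trained mlip force field object

predictions = run_batched_inference(
    structures=structures,
    force_field=force_field,
    batch_size=1,
    max_n_node=256,   # set explicitly to avoid edge cases in the automated
    max_n_edge=4096,  # computation of these parameters, per the note above
)
# One Prediction object is returned per input structure.
print(len(predictions))
```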

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mlip"
-version = "0.0.1a5"
+version = "0.1.0"
 description = "Machine Learning Interatomic Potentials in JAX"
 license = "LICENSE"
 authors = [

tutorials/model_addition_tutorial.ipynb

Lines changed: 2 additions & 2 deletions
@@ -34,10 +34,10 @@
 },
 "outputs": [],
 "source": [
-"!pip install mlip \"jax[cuda12]==0.4.33\"\n",
+"%pip install mlip \"jax[cuda12]==0.4.33\"\n",
 "\n",
 "# Use this instead for installation without GPU:\n",
-"# !pip install mlip"
+"# %pip install mlip"
 ]
 },
 {

tutorials/model_training_tutorial.ipynb

Lines changed: 2 additions & 2 deletions
@@ -57,10 +57,10 @@
 },
 "outputs": [],
 "source": [
-"!pip install mlip \"jax[cuda12]==0.4.33\" huggingface_hub\n",
+"%pip install mlip \"jax[cuda12]==0.4.33\" huggingface_hub\n",
 "\n",
 "# Use this instead for installation without GPU:\n",
-"# !pip install mlip huggingface_hub"
+"# %pip install mlip huggingface_hub"
 ]
 },
 {

tutorials/simulation_tutorial.ipynb

Lines changed: 2 additions & 2 deletions
@@ -66,10 +66,10 @@
 },
 "outputs": [],
 "source": [
-"!pip install mlip \"jax[cuda12]==0.4.33\" huggingface_hub git+https://github.com/jax-md/jax-md.git\n",
+"%pip install mlip \"jax[cuda12]==0.4.33\" huggingface_hub git+https://github.com/jax-md/jax-md.git\n",
 "\n",
 "# Use this instead for installation without GPU:\n",
-"# !pip install mlip huggingface_hub git+https://github.com/jax-md/jax-md.git"
+"# %pip install mlip huggingface_hub git+https://github.com/jax-md/jax-md.git"
 ]
 },
 {
