Skip to content

Commit ccc72a6

Browse files
authored
feat: update python (#84)
* feat: make jaxmd a project dependency * feat: upgrade python version to 3.11 * feat: add optional gpu dependencies * feat: test tox without dependencies * feat: attempt fixing ci tests * docs: update README.md * ci: try uv venv 3.13 * ci: switch back to tox * docs: fix badge * docs: update README.md * chore: remove tox pr trigger * docs: update installation * docs: remove editable install
1 parent 8e575d5 commit ccc72a6

File tree

6 files changed

+19
-64
lines changed

6 files changed

+19
-64
lines changed

.github/workflows/tests_and_linters.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353

5454
- name: Run tests 🧪
5555
run: |
56-
uv sync --group dev --group jax_md
56+
uv sync --only-dev
5757
uv run pytest --verbose --cov-report xml:coverage.xml \
5858
--cov-report term-missing \
5959
--junitxml=pytest.xml \

.github/workflows/tox_tests.yaml

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,6 @@ on:
55
branches: [ main, develop ]
66

77
jobs:
8-
tox_py310:
9-
runs-on: instadeep-ci-4
10-
container:
11-
image: python:3.10-slim-bullseye
12-
13-
steps:
14-
- name: Git setup
15-
run: |
16-
apt-get update && apt-get install -y coreutils git
17-
18-
- name: Checkout code 📦
19-
uses: actions/checkout@v4
20-
with:
21-
fetch-depth: '0'
22-
23-
- name: Run tests with tox 🧪
24-
run: |
25-
pip install tox==4.30.3
26-
tox -e py310
27-
288
tox_py311:
299
runs-on: instadeep-ci-4
3010
container:

README.md

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# 🔬 MLIPAudit: A library to validate and benchmark MLIP models
22

33
[![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv)
4-
[![Python 3.11](https://img.shields.io/badge/python-3.10%20%7C%203.11-blue)](https://www.python.org/downloads/release/python-3110/)
4+
[![Python 3.11](https://img.shields.io/badge/python-3.11%20%7C%203.12%20%7C%203.13-blue))](https://www.python.org/downloads/release/python-3110/)
55
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit)
66
![badge](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/mlipbot/e7c79b17c0a9d47bc826100ef880a16f/raw/pytest-coverage-comment.json)
77
[![Tests and Linters 🧪](https://github.com/instadeepai/mlipaudit-open/actions/workflows/tests_and_linters.yaml/badge.svg?branch=main)](https://github.com/instadeepai/mlipaudit-open/actions/workflows/tests_and_linters.yaml)
@@ -29,21 +29,13 @@ MLIPAudit can be installed via pip:
2929
pip install mlipaudit
3030
```
3131

32-
However, this command **only installs the regular CPU version** of JAX.
33-
We recommend that MLIPAudit is run on GPU. Also, some benchmarks will require
34-
[JAX-MD](https://github.com/jax-md/jax-md) as a dependency. As the newest
35-
version of JAX-MD is not available on PyPI yet, this dependency will not
36-
be shipped with MLIPAudit automatically and instead must be installed
37-
directly from the GitHub repository.
38-
39-
Therefore, we recommend running
40-
32+
However, this command **only installs the regular CPU version** of JAX. If benchmarking
33+
native JAX models, we recommend installing the core library along with the GPU
34+
dependencies (`jax[cuda12]` and `jaxlib`) with the following command:
4135
```bash
42-
pip install -U "jax[cuda12]" git+https://github.com/jax-md/jax-md.git
36+
pip install mlipaudit[gpu]
4337
```
4438

45-
to install both of these additional packages.
46-
4739
## 📖 Documentation
4840

4941
The detailed code documentation that also contains descriptions for each benchmark and

docs/source/installation/index.rst

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,9 @@ After installation and activating the respective Python environment, the command
1313
tools `mlipaudit` and `mlipauditapp` should be available.
1414

1515
However, the command above **only installs the regular CPU version** of JAX.
16-
We recommend that the library is run on GPU.
17-
This requires also installing the necessary versions
18-
of `jaxlib <https://pypi.org/project/jaxlib/>`_ which can also be installed via pip. See
19-
the `installation guide of JAX <https://docs.jax.dev/en/latest/installation.html>`_ for
20-
more information.
21-
At time of release, the following install command is supported:
16+
If benchmarking native JAX models, we recommend installing the core library
17+
along with the GPU dependencies (`jax[cuda12]` and `jaxlib`) with the following command:
2218

2319
.. code-block:: bash
2420
25-
pip install -U "jax[cuda12]"
26-
27-
Also, some benchmarks require `JAX-MD <https://github.com/jax-md/jax-md>`_ as a
28-
dependency. As the newest
29-
version of JAX-MD is not available on PyPI yet, this dependency will not
30-
be shipped with *mlipaudit* automatically and instead must be installed
31-
directly from the GitHub repository, like this:
32-
33-
.. code-block:: bash
34-
35-
pip install git+https://github.com/jax-md/jax-md.git
21+
pip install mlipaudit[gpu]

pyproject.toml

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ readme = "README.md"
66
authors = [
77
{ name = "InstaDeep", email = "[email protected]" }
88
]
9-
requires-python = ">=3.10"
9+
requires-python = ">=3.11"
1010
dependencies = [
1111
"huggingface-hub>=0.33.4",
1212
"mdtraj>=1.10.3",
@@ -16,6 +16,7 @@ dependencies = [
1616
"vl-convert-python>=1.8.0",
1717
"mdtraj>=1.10.3",
1818
"tmtools>=0.2.0",
19+
"jax-md>=0.2.26"
1920
]
2021

2122
[project.scripts]
@@ -25,6 +26,12 @@ mlipaudit = "mlipaudit.main:main"
2526
requires = ["hatchling"]
2627
build-backend = "hatchling.build"
2728

29+
[project.optional-dependencies]
30+
gpu = [
31+
"jax[cuda12]>=0.4.33",
32+
"jaxlib>=0.4.33"
33+
]
34+
2835
[dependency-groups]
2936
dev = [
3037
"notebook>=7.4.4",
@@ -39,19 +46,9 @@ dev = [
3946
scripts = [
4047
"requests>=2.32.4",
4148
]
42-
gpu = [
43-
"jax[cuda12]>=0.4.33",
44-
"jaxlib>=0.4.33"
45-
]
46-
jax_md = [
47-
"jax-md",
48-
]
4949

5050
[tool.coverage.run]
5151
omit = [
5252
"main.py",
5353
"app.py"
5454
]
55-
56-
[tool.uv.sources]
57-
jax-md = { git = "https://github.com/jax-md/jax-md.git", tag = "jax-md-v0.2.25" }

tox.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
[tox]
22
isolated_build = True
3-
env_list = py{310,311,312,313}
3+
env_list = py{311,312,313}
44

55
[testenv]
66
deps =
77
pytest>=8.4.1
88
pytest-mock>=3.14.1
9-
git+https://github.com/jax-md/jax-md.git
9+
description = Test various python versions
1010
commands = pytest {posargs}

0 commit comments

Comments
 (0)