Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/deploy_docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:

- name: Build documentation
run: |
uv run sphinx-build -b html docs/source/ _build/
uv run --only-dev sphinx-build -b html docs/source/ _build/

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@v3
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ However, this command **only installs the regular CPU version** of JAX. If bench
native JAX models, we recommend installing the core library along with the GPU
dependencies (`jax[cuda12]` and `jaxlib`) with the following command:
```bash
pip install mlipaudit[gpu]
pip install mlipaudit[cuda]
```

## 📖 Documentation
Expand Down Expand Up @@ -158,10 +158,10 @@ package and dependency management.

This command installs all dependency groups. We recommend to check out
the `pyproject.toml` file for information on the available groups. Most notably,
the group `gpu` installs the GPU-ready version of JAX which are strongly recommended.
If you do not want to install the `gpu` dependency group (for example, because you are
the group `cuda` installs the GPU-ready version of JAX which are strongly recommended.
If you do not want to install the `cuda` dependency group (for example, because you are
on MacOS that does not support this standard installation), you can use the
`--no-group gpu` option in the [uv](https://docs.astral.sh/uv/) command.
`--no-group cuda` option in the [uv](https://docs.astral.sh/uv/) command.

When adding new benchmarks, make sure that the following key pieces are added
for each one:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/installation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ along with the GPU dependencies (`jax[cuda12]` and `jaxlib`) with the following

.. code-block:: bash

pip install mlipaudit[gpu]
pip install mlipaudit[cuda]
9 changes: 4 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ requires-python = ">=3.11"
dependencies = [
"huggingface-hub>=0.33.4",
"mdtraj>=1.10.3",
"mlip>=0.1.4",
"mlip>=0.1.6",
"scikit-learn>=1.7.0",
"streamlit>=1.46.1",
"vl-convert-python>=1.8.0",
"mdtraj>=1.10.3",
"tmtools>=0.2.0",
"jax-md>=0.2.26"
"jax-md>=0.2.26",
]

[project.scripts]
Expand All @@ -27,9 +27,8 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[project.optional-dependencies]
gpu = [
"jax[cuda12]>=0.4.33",
"jaxlib>=0.4.33"
cuda = [
"mlip[cuda]>=0.1.6",
]

[dependency-groups]
Expand Down