
Commit d5afa35

Alchemiops KNN+D3, TorchSim, Refactor (#143)
* Switch to uv
* Update code
* Update tests
* Readmes, examples, fixes
* TorchSim and D3 examples, bug fixes
* Use uv in CI
* Clean up
* Dockerfile
* Made torch-sim optional
* Removed default loss_weights from conservative/direct_regressor, added them in pretrained.py instead
* Added changelog to README.md
1 parent 6ad92b4 commit d5afa35

File tree

125 files changed: +10094 −7134 lines


.flake8

Lines changed: 0 additions & 34 deletions
This file was deleted.

.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 # the repo. Unless a later match takes precedence,
 # @global-owner1 and @global-owner2 will be requested for
 # review when someone opens a pull request.
-* @DeNeutoy @benrhodes26 @jg8610 @reactiv @zhiyil1230
+* @benrhodes26 @vsimkus @jg8610 @reactiv

.github/workflows/publish.yml

Lines changed: 12 additions & 5 deletions
@@ -12,7 +12,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10", "3.11", "3.12"]
+        python-version: ["3.12", "3.13", "3.14"]
     if: startsWith(github.ref, 'refs/tags/')
     steps:
       - uses: actions/checkout@v4
@@ -25,13 +25,20 @@ jobs:
           python-version: ${{ matrix.python-version }}
           cache: pip
           cache-dependency-path: pyproject.toml
-      - name: Install dependencies
+      - name: Set up uv
         run: |
-          pip install '.[test]'
+          pip install uv
+      - name: Install dependencies
+        run: uv sync --group dev

-      - name: Run tests
+      - name: Run linters
         run: |
-          pytest
+          uv run ruff check ./orb_models ./tests ./scripts
+          uv run ruff format --check ./orb_models ./tests ./scripts
+          uv run mypy ./orb_models ./tests ./scripts
+      - name: Run tests
+        run: uv run pytest ./tests -n auto

   deploy:
     runs-on: ubuntu-latest
     needs: [test]

.github/workflows/test.yml

Lines changed: 12 additions & 5 deletions
@@ -5,12 +5,13 @@ on: [push, pull_request]
 permissions:
   contents: read

+
 jobs:
   test:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10", "3.11", "3.12"]
+        python-version: ["3.12", "3.13", "3.14"]
     steps:
       - uses: actions/checkout@v4
       - name: Update version
@@ -22,10 +23,16 @@ jobs:
           python-version: ${{ matrix.python-version }}
           cache: pip
           cache-dependency-path: pyproject.toml
-      - name: Install dependencies
+      - name: Set up uv
         run: |
-          pip install '.[test]'
+          pip install uv
+      - name: Install dependencies
+        run: uv sync --group dev

-      - name: Run tests
+      - name: Run linters
         run: |
-          pytest
+          uv run ruff check ./orb_models ./tests ./scripts
+          uv run ruff format --check ./orb_models ./tests ./scripts
+          uv run mypy ./orb_models ./tests ./scripts
+      - name: Run tests
+        run: uv run pytest ./tests -n auto

.gitignore

Lines changed: 10 additions & 1 deletion
@@ -1,3 +1,6 @@
+# MacOS
+.DS_Store
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -101,6 +104,10 @@ ipython_config.py
 # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
 #poetry.lock

+# uv
+# https://docs.astral.sh/uv/concepts/projects/sync/#upgrading-locked-package-versions
+uv.lock
+
 # pdm
 # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
 #pdm.lock
@@ -166,5 +173,7 @@ datasets/

 # VS Code
 .devcontainer/
+.vscode/

-wandb/
+wandb/
+ckpts/

.python-version

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Python version, ensures consistent development environment
+# If the version changes in this file, running `uv sync` or `uv sync --group dev`
+# will download the relevant Python version
+3.12

CONTRIBUTING.md

Lines changed: 13 additions & 9 deletions
@@ -1,26 +1,30 @@
 ### Setting up a development environment

-The `orb_models` package uses [Poetry](https://python-poetry.org/) for dependency management. To install the package and its dependencies, run the following command:
+The `orb_models` repository uses [uv](https://docs.astral.sh/uv/) for dependency management. To install the package and its dependencies, run the following command:

 ```bash
-pip install poetry # Install Poetry if you don't have it
-poetry install
+pipx install uv # If you don't have uv, we recommend installing it into an isolated environment with pipx: https://docs.astral.sh/uv/getting-started/installation/#pypi
+uv sync --group dev # Install orb-models and development packages
 ```

-Optionally, also install [cuML](https://docs.rapids.ai/install/) (requires CUDA):
+### Running linters
+
+The `orb_models` repository uses `ruff` for formatting and linting, and `mypy` for type checking. To run the linters, use the following commands:
+
 ```bash
-pip install "cuml-cu11==25.2.*" # For cuda versions >=11.4, <11.8
-pip install "cuml-cu12==25.2.*" # For cuda versions >=12.0, <13.0
+ruff format . # Format code
+ruff check . # Check for linting errors
+mypy . # Run type checking
 ```

 ### Running tests

-The `orb_models` package uses `pytest` for testing. To run the tests, navigate to the root directory of the package and run the following command:
+The `orb_models` repository uses `pytest` for testing. To run the tests, navigate to the root directory of the package and run the following command:

 ```bash
-pytest
+pytest -n auto ./tests/
 ```

 ### Publishing

-The `orb_models` package is published using [trusted publishers](https://docs.pypi.org/trusted-publishers/). Whenever a new release is created on GitHub, the package is automatically published to PyPI using GitHub Actions.
+The `orb_models` package is published using [trusted publishers](https://docs.pypi.org/trusted-publishers/). Whenever a new release is created on GitHub, the package is automatically published to PyPI using GitHub Actions.

Dockerfile

Lines changed: 8 additions & 10 deletions
@@ -1,4 +1,4 @@
-FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-runtime
+FROM pytorch/pytorch:2.10.0-cuda12.6-cudnn9-runtime

 ENV DEBIAN_FRONTEND=noninteractive

@@ -10,14 +10,12 @@ RUN apt-get update && \
     git \
     sudo \
     gcc \
-    g++
-
-# Help Numba find lubcudart.so
-ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/nvidia/cuda_runtime/lib:$LD_LIBRARY_PATH
-RUN ln -s \
-    /opt/conda/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 \
-    /opt/conda/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so
+    g++ \
+    # Cleanup
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

 ## Install Python requirements
-RUN pip install orb-models && \
-    pip install "cuml-cu12==25.2.*"
+# For details on --break-system-packages, see: https://veronneau.org/python-311-pip-and-breaking-system-packages.html
+RUN pip install --break-system-packages orb-models && \
+    pip install --break-system-packages "cuml-cu12==25.2.*"

FINETUNING_GUIDE.md

Lines changed: 10 additions & 10 deletions
@@ -243,9 +243,11 @@ Lines starting with `#` are treated as comments and ignored.

 ### Loss Weights

-- `--energy_loss_weight`: Weight for energy loss (default: uses model default, usually 1.0)
-- `--forces_loss_weight`: Weight for forces loss (automatically uses correct key for model type)
+- `--energy_loss_weight`: Weight for energy loss (default: 1.0)
+- `--forces_loss_weight`: Weight for forces loss (default: 1.0)
 - `--stress_loss_weight`: Weight for stress loss (set to 0 to disable)
+- `--equigrad_loss_weight`: Weight for the Equigrad loss (turned off by default). Only available for the conservative models.
+  - NOTE: We've found that Equigrad loss should be ≳1000x smaller than the other losses

 ### Reference Energies

@@ -274,7 +276,7 @@ If you prefer to write your own finetuning script, you can use the clean API directly:
 from orb_models.forcefield import pretrained

 # Load model with custom configuration
-model = pretrained.orb_v3_conservative_omol(
+model, atoms_adapter = pretrained.orb_v3_conservative_omol(
     device='cuda',
     precision='float32-high',
     train=True,
@@ -287,7 +289,7 @@
 )

 # For direct models, use 'forces' and 'stress' keys:
-model = pretrained.orb_v3_direct_omol(
+model, atoms_adapter = pretrained.orb_v3_direct_omol(
     device='cuda',
     train=True,
     loss_weights={
@@ -300,8 +302,6 @@
 # The model is now ready for training with your custom configuration!
 ```

-This approach is more "pythonic" and clearly documents what configuration options are available. It also encapsulates the implementation details, making your code less fragile to internal changes.
-
 ## How It Works

 ### Reference Energies
@@ -319,7 +319,7 @@ import torch
 from orb_models.forcefield import pretrained

 # Load model architecture (set train=False for inference)
-model = pretrained.orb_v3_conservative_omol(train=False)
+model, atoms_adapter = pretrained.orb_v3_conservative_omol(train=False)

 # Load your finetuned checkpoint
 model.load_state_dict(torch.load('path/to/finetuned_checkpoint.pt'))
@@ -331,7 +331,7 @@ You can also specify loss weights when loading for further finetuning:

 ```python
 # Load for continued finetuning with different loss weights
-model = pretrained.orb_v3_conservative_omol(
+model, atoms_adapter = pretrained.orb_v3_conservative_omol(
     train=True,
     loss_weights={'energy': 0.5, 'grad_forces': 20.0}
 )
@@ -370,7 +370,7 @@ python finetune.py \
 from orb_models.forcefield import pretrained
 import torch

-model = pretrained.orb_v3_conservative_omol(train=False)
+model, atoms_adapter = pretrained.orb_v3_conservative_omol(train=False)
 model.load_state_dict(torch.load('checkpoints/my_finetuned_model.pt'))
 # Reference energies from my_refs.json are now loaded!
 ```
@@ -384,7 +384,7 @@ from orb_models.forcefield import pretrained
 from orb_models.dataset.ase_sqlite_dataset import AseSqliteDataset

 # Load model with configuration
-model = pretrained.orb_v3_conservative_omol(
+model, atoms_adapter = pretrained.orb_v3_conservative_omol(
     device='cuda',
     train=True,
     train_reference_energies=False, # Fixed reference energies
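The Loss Weights flags in the hunks above map onto the `loss_weights` dictionary used in the guide's Python snippets. A minimal sketch of a weight configuration consistent with the Equigrad advice — the `'equigrad'` key name is an assumption for illustration; only `'energy'` and `'grad_forces'` appear in this diff:

```python
# Hypothetical loss-weight configuration for a conservative model, following
# the guide's advice that the Equigrad weight be ~1000x smaller than the
# other losses. The 'equigrad' key name is assumed, not confirmed by the diff.
loss_weights = {
    "energy": 1.0,
    "grad_forces": 1.0,
    "equigrad": 0.001,  # >= ~1000x smaller than the other weights
}

# Sanity check: the Equigrad weight really is at least 1000x smaller.
others = [w for k, w in loss_weights.items() if k != "equigrad"]
assert all(loss_weights["equigrad"] <= w / 1000 for w in others)
```

This is only a sketch under the stated assumptions; for the actual key names per model type, see the `pretrained.py` snippets in the hunks above.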

MODELS.md

Lines changed: 6 additions & 6 deletions
@@ -4,7 +4,7 @@ We provide several pretrained models that can be used to calculate energies, for

 ### OrbMol Models

-These models are a continuation of the [`orb-v3`](#v3-models) series trained on the [Open Molecules 2025 (OMol25)](https://arxiv.org/pdf/2505.08762) dataset—over 100M high-accuracy DFT calculations (ωB97M-V/def2-TZVPD) on diverse molecular systems including metal complexes, biomolecules, and electrolytes. Note: The training data does not contain periodic systmems and these models have not been carefully tested on periodic systems.
+These models are a continuation of the [`orb-v3`](#v3-models) series trained on the [Open Molecules 2025 (OMol25)](https://arxiv.org/pdf/2505.08762) dataset—over 100M high-accuracy DFT calculations (ωB97M-V/def2-TZVPD) on diverse molecular systems including metal complexes, biomolecules, and electrolytes. Note: The training data does not contain periodic systems and these models have not been carefully tested on periodic systems.

 There are two options:
 * `orb-v3-conservative-omol`
@@ -15,11 +15,11 @@ See below for more explanation of this naming convention. Both models have `inf`
 ### [V3 Models](https://arxiv.org/abs/2504.06231)

 V3 models use the following naming convention: ```orb-v3-X-Y-Z``` where:
-- `X`: Model type - `direct` or `conservative`. Conservative models compute forces and stress via backpropagation, which is a physically motivated choice that appears necessary for certain types of simulation such as NVE Molecular dynamics. Conservative models are signficantly slower and use more memory than their direct counterparts.
+- `X`: Model type - `direct` or `conservative`. Conservative models compute forces and stress via backpropagation, which is a physically motivated choice that appears necessary for certain types of simulation such as NVE Molecular dynamics. Conservative models are significantly slower and use more memory than their direct counterparts.

-- `Y`: Maximum neighbors per atom: `20` or `inf`. A finite cutoff of `20` induces discontinuties in the PES, which can lead to significant inaccuracies for certain types of highly sensitive calculations (e.g. calculations involving Hessians). However, finite cutoffs reduce the amount of edge processing in the network, reducing latency and memory use.
+- `Y`: Maximum neighbors per atom: `20` or `inf`. A finite cutoff of `20` induces discontinuities in the PES, which can lead to significant inaccuracies for certain types of highly sensitive calculations (e.g. calculations involving Hessians). However, finite cutoffs reduce the amount of edge processing in the network, reducing latency and memory use.

-- `Z`: Training dataset - `omat` or `mpa`. Both of these dataset consist of small bulk crystal structures. We find that models trained on such data can generalise reasonably well to non-periodic systems (organic molecules) or partially periodic systems (slabs), but caution is advised in these scenarios.
+- `Z`: Training dataset - `omat` or `mpa`. Both of these datasets consist of small bulk crystal structures. We find that models trained on such data can generalise reasonably well to non-periodic systems (organic molecules) or partially periodic systems (slabs), but caution is advised in these scenarios.

 #### Features:

@@ -35,7 +35,7 @@ V3 models use the following naming convention: ```orb-v3-X-Y-Z``` where:
 #### Advice / Caveats

 - Consider using `orb-v3-conservative-120-omat` for initial testing, specifying `precision='float32-highest'` when loading the model. This is the most computational expensive but accurate configuration. If this level of accuracy meets your needs, then other models and precisions can be investigated to improve speed and system-size scalability.
-- We do not advise using the `-mpa` models unless they are required for compatability with benchmarks (for example, Matbench Discovery). They are generally less performant.
+- We do not advise using the `-mpa` models unless they are required for compatibility with benchmarks (for example, Matbench Discovery). They are generally less performant.
 - Orb-v3 models are **compiled** by default and use Pytorch's dynamic batching, which means that they do not need to recompile as graph sizes change. However, the first call to the model will be slower, as the graph is compiled by torch.

 ### [V2 Models](https://arxiv.org/abs/2410.22570)
@@ -53,4 +53,4 @@

 ### [V1 Models](https://arxiv.org/abs/2410.22570)

-Our initial release. These models were state of the art performance on the matbench discovery dataset at time of release, but have since been superceeded and removed.
+Our initial release. These models were state of the art performance on the matbench discovery dataset at time of release, but have since been superseded and removed.
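The `orb-v3-X-Y-Z` convention described in the hunks above is mechanical enough to sketch in code. This is only an illustration of the namespace the convention implies — it makes no claim about which combinations were actually released as checkpoints:

```python
# Enumerate the model names implied by MODELS.md's orb-v3-X-Y-Z convention.
# X: force/stress mode, Y: max neighbors per atom, Z: training dataset.
model_types = ["direct", "conservative"]  # X
max_neighbors = ["20", "inf"]             # Y
datasets = ["omat", "mpa"]                # Z

names = [
    f"orb-v3-{x}-{y}-{z}"
    for x in model_types
    for y in max_neighbors
    for z in datasets
]

assert len(names) == 8  # 2 x 2 x 2 combinations
assert "orb-v3-conservative-inf-omat" in names
```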
