Skip to content

Commit 62ce470

Browse files
authored
Merge pull request #30 from lab-cosmo/pre-release-uq
Pre-release the UQ update
2 parents c720d63 + 89d9aee commit 62ce470

File tree

9 files changed

+151
-59
lines changed

9 files changed

+151
-59
lines changed

.github/workflows/release.yml

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
name: Release
2+
3+
on:
4+
push:
5+
tags: ["*"]
6+
7+
jobs:
8+
build:
9+
name: Build distribution
10+
runs-on: ubuntu-latest
11+
environment:
12+
name: pypi
13+
url: https://pypi.org/project/pet-mad
14+
permissions:
15+
id-token: write
16+
contents: write
17+
18+
steps:
19+
- uses: actions/checkout@v4
20+
with:
21+
fetch-depth: 0
22+
- name: setup Python
23+
uses: actions/setup-python@v5
24+
with:
25+
python-version: "3.13"
26+
- run: python -m pip install tox
27+
- name: Build package
28+
run: tox -e build
29+
- name: Publish distribution to PyPI
30+
if: startsWith(github.ref, 'refs/tags/v')
31+
uses: pypa/gh-action-pypi-publish@release/v1
32+
- name: Publish to GitHub release
33+
if: startsWith(github.ref, 'refs/tags/v')
34+
uses: softprops/action-gh-release@v2
35+
with:
36+
files: |
37+
dist/*.tar.gz
38+
dist/*.whl
39+
prerelease: ${{ contains(github.ref, '-rc') }}
40+
env:
41+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ __pycache__/
77
*.so
88

99
# Distribution / packaging
10-
_version.py
1110
.Python
1211
build/
1312
develop-eggs/

src/pet_mad/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
from ._models import get_pet_mad, save_pet_mad # noqa: F401
2-
31
__version__ = "1.3.0"

src/pet_mad/_models.py

Lines changed: 36 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,40 @@
11
import importlib.util
22
import logging
33
import warnings
4+
from typing import Optional
45

5-
from metatomic.torch import AtomisticModel, ModelMetadata
6+
from metatomic.torch import AtomisticModel
67
from metatrain.utils.io import load_model as load_metatrain_model
8+
from .utils import get_metadata
79

10+
from packaging.version import Version
811

9-
METADATA = ModelMetadata(
10-
name="PET-MAD",
11-
description="A universal interatomic potential for advanced materials modeling",
12-
authors=[
13-
"Arslan Mazitov (arslan.mazitov@epfl.ch)",
14-
"Filippo Bigi",
15-
"Matthias Kellner",
16-
"Paolo Pegolo",
17-
"Davide Tisi",
18-
"Guillaume Fraux",
19-
"Sergey Pozdnyakov",
20-
"Philip Loche",
21-
"Michele Ceriotti (michele.ceriotti@epfl.ch)",
22-
],
23-
references={
24-
"architecture": ["https://arxiv.org/abs/2305.19302v3"],
25-
"model": ["http://arxiv.org/abs/2503.14118"],
26-
},
27-
)
28-
VERSIONS = ("latest", "1.1.0", "1.0.1", "1.0.0")
29-
BASE_URL = (
30-
"https://huggingface.co/lab-cosmo/pet-mad/resolve/{}/models/pet-mad-latest.ckpt"
31-
)
32-
33-
34-
def get_pet_mad(*, version="latest", checkpoint_path=None) -> AtomisticModel:
12+
from ._version import LATEST_VERSION, AVAILABLE_VERSIONS
13+
14+
BASE_URL = "https://huggingface.co/lab-cosmo/pet-mad/resolve/{tag}/models/pet-mad-{version}.ckpt"
15+
16+
17+
def get_pet_mad(
18+
*, version: str = "latest", checkpoint_path: Optional[str] = None
19+
) -> AtomisticModel:
3520
"""Get a metatomic ``AtomisticModel`` for PET-MAD.
3621
37-
:param version: PET-MAD version to use. Supported versions are "latest", "1.1.0",
38-
"1.0.1", "1.0.0". Defaults to "latest".
22+
:param version: PET-MAD version to use. Supported versions are
23+
"1.1.0", "1.0.1", "1.0.0". Defaults to latest available version.
3924
:param checkpoint_path: path to a checkpoint file to load the model from. If
4025
provided, the `version` parameter is ignored.
4126
"""
42-
if version not in VERSIONS:
27+
if version == "latest":
28+
version = Version(LATEST_VERSION)
29+
if not isinstance(version, Version):
30+
version = Version(version)
31+
32+
if version not in [Version(v) for v in AVAILABLE_VERSIONS]:
4333
raise ValueError(
44-
f"Version {version} is not supported. Supported versions are {VERSIONS}"
34+
f"Version {version} is not supported. Supported versions are {AVAILABLE_VERSIONS}"
4535
)
4636

47-
if version == "1.0.0":
37+
if version == Version("1.0.0"):
4838
if not importlib.util.find_spec("pet_neighbors_convert"):
4939
raise ImportError(
5040
f"PET-MAD v{version} is now deprecated. Please consider using the "
@@ -60,9 +50,7 @@ def get_pet_mad(*, version="latest", checkpoint_path=None) -> AtomisticModel:
6050
path = checkpoint_path
6151
else:
6252
logging.info(f"Downloading PET-MAD model version: {version}")
63-
path = BASE_URL.format(
64-
f"v{version}" if version not in ("latest", "1.1.0") else "main"
65-
)
53+
path = BASE_URL.format(tag=f"v{version}", version=f"v{version}")
6654

6755
with warnings.catch_warnings():
6856
warnings.filterwarnings(
@@ -71,25 +59,31 @@ def get_pet_mad(*, version="latest", checkpoint_path=None) -> AtomisticModel:
7159
)
7260
model = load_metatrain_model(path)
7361

74-
return model.export(METADATA)
62+
metadata = get_metadata(version)
63+
return model.export(metadata)
7564

7665

77-
def save_pet_mad(*, version="latest", checkpoint_path=None, output=None):
66+
def save_pet_mad(*, version: str = "latest", checkpoint_path=None, output=None):
7867
"""
7968
Save the PET-MAD model to a TorchScript file (``pet-mad-xxx.pt``). These files can
8069
be used with LAMMPS and other tools to run simulations without Python.
8170
82-
:param version: PET-MAD version to use. Supported versions are "latest", "1.1.0",
83-
"1.0.1", "1.0.0". Defaults to "latest".
71+
:param version: PET-MAD version to use. Supported versions are "1.1.0",
72+
"1.0.1", "1.0.0". Defaults to the latest version.
8473
:param checkpoint_path: path to a checkpoint file to load the model from. If
8574
provided, the `version` parameter is ignored.
8675
:param output: path to use for the output model, defaults to
8776
``pet-mad-{version}.pt`` when using a version, or the checkpoint path when using
8877
a checkpoint.
8978
"""
79+
if version == "latest":
80+
version = Version(LATEST_VERSION)
81+
if not isinstance(version, Version):
82+
version = Version(version)
83+
9084
extensions_directory = None
91-
if version == "1.0.0":
92-
logging.info("putting TorchScript extensions in `extensions/`")
85+
if version == Version("1.0.0"):
86+
logging.info("Putting TorchScript extensions in `extensions/`")
9387
extensions_directory = "extensions"
9488

9589
model = get_pet_mad(version=version, checkpoint_path=checkpoint_path)
@@ -101,4 +95,4 @@ def save_pet_mad(*, version="latest", checkpoint_path=None, output=None):
10195
raise
10296

10397
model.save(output, collect_extensions=extensions_directory)
104-
logging.info(f"saved pet-mad model to {output}")
98+
logging.info(f"Saved PET-MAD model to {output}")

src/pet_mad/_version.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
LATEST_VERSION = "1.1.0"
2+
AVAILABLE_VERSIONS = ["1.1.0", "1.0.1", "1.0.0"]

src/pet_mad/calculator.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
from metatomic.torch.ase_calculator import MetatomicCalculator
66
from platformdirs import user_cache_dir
77

8+
from packaging.version import Version
9+
810
from ._models import get_pet_mad
11+
from ._version import LATEST_VERSION
912

1013

1114
class PETMADCalculator(MetatomicCalculator):
@@ -18,30 +21,37 @@ def __init__(
1821
version: str = "latest",
1922
checkpoint_path: Optional[str] = None,
2023
*,
21-
non_conservative=False,
22-
check_consistency=False,
23-
device=None,
24+
check_consistency: bool = False,
25+
device: Optional[str] = None,
26+
non_conservative: bool = False,
2427
):
2528
"""
26-
:param version: PET-MAD version to use. Supported versions are "latest", "1.1.0",
27-
"1.0.1", "1.0.0". Defaults to "latest".
29+
:param version: PET-MAD version to use. Supported versions are
30+
"1.1.0", "1.0.1", "1.0.0". Defaults to latest available version.
2831
:param checkpoint_path: path to a checkpoint file to load the model from. If
2932
provided, the `version` parameter is ignored.
30-
:param non_conservative: whether to use the non-conservative regime of forces and
31-
stresses prediction. Defaults to False.
3233
:param check_consistency: should we check the model for consistency when
3334
running, defaults to False.
3435
:param device: torch device to use for the calculation. If `None`, we will try
3536
the options in the model's `supported_device` in order.
37+
:param non_conservative: whether to use the non-conservative regime of forces
38+
and stresses prediction. Defaults to False. Only available for PET-MAD
39+
version 1.1.0 or higher.
40+
3641
"""
3742

43+
if version == "latest":
44+
version = Version(LATEST_VERSION)
45+
if not isinstance(version, Version):
46+
version = Version(version)
47+
3848
model = get_pet_mad(version=version, checkpoint_path=checkpoint_path)
3949

4050
cache_dir = user_cache_dir("pet-mad", "metatensor")
4151
os.makedirs(cache_dir, exist_ok=True)
4252

4353
extensions_directory = None
44-
if version == "1.0.0":
54+
if version == Version("1.0.0"):
4555
extensions_directory = "extensions"
4656

4757
pt_path = cache_dir + f"/pet-mad-{version}.pt"
@@ -56,9 +66,9 @@ def __init__(
5666

5767
super().__init__(
5868
pt_path,
69+
additional_outputs={},
5970
extensions_directory=extensions_directory,
6071
check_consistency=check_consistency,
6172
device=device,
6273
non_conservative=non_conservative,
63-
additional_outputs={},
6474
)

src/pet_mad/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from metatomic.torch import ModelMetadata
2+
3+
4+
def get_metadata(version: str):
5+
return ModelMetadata(
6+
name=f"PET-MAD v{version}",
7+
description="A universal interatomic potential for advanced materials modeling",
8+
authors=[
9+
"Arslan Mazitov (arslan.mazitov@epfl.ch)",
10+
"Filippo Bigi",
11+
"Matthias Kellner",
12+
"Paolo Pegolo",
13+
"Davide Tisi",
14+
"Guillaume Fraux",
15+
"Sergey Pozdnyakov",
16+
"Philip Loche",
17+
"Michele Ceriotti (michele.ceriotti@epfl.ch)",
18+
],
19+
references={
20+
"architecture": ["https://arxiv.org/abs/2305.19302v3"],
21+
"model": ["http://arxiv.org/abs/2503.14118"],
22+
},
23+
)

tests/test_basic_usage.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,28 @@
11
from pet_mad.calculator import PETMADCalculator
2+
from pet_mad._models import get_pet_mad, save_pet_mad
23
from ase.build import bulk
4+
import os
35
import pytest
46

7+
VERSIONS = ("1.1.0", "1.0.1")
8+
9+
@pytest.mark.parametrize(
10+
"version",
11+
VERSIONS,
12+
)
13+
def test_get_pet_mad(version):
14+
model = get_pet_mad(version=version)
15+
assert model.metadata().name == f"PET-MAD v{version}"
16+
17+
18+
@pytest.mark.parametrize(
19+
"version",
20+
VERSIONS,
21+
)
22+
def test_save_pet_mad(version, monkeypatch, tmp_path):
23+
monkeypatch.chdir(tmp_path)
24+
save_pet_mad(version=version, output=f"pet-mad-{version}.pt")
25+
assert os.path.exists(f"pet-mad-{version}.pt")
526

627
def test_basic_usage():
728
atoms = bulk("Si", cubic=True, a=5.43, crystalstructure="diamond")
@@ -12,11 +33,7 @@ def test_basic_usage():
1233

1334
@pytest.mark.parametrize(
1435
"version",
15-
[
16-
"latest",
17-
"1.1.0",
18-
"1.0.1",
19-
],
36+
VERSIONS,
2037
)
2138
def test_version(version):
2239
atoms = bulk("Si", cubic=True, a=5.43, crystalstructure="diamond")

tests/test_metadata.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pet_mad.utils import get_metadata
2+
3+
def test_get_metadata():
4+
metadata = get_metadata("1.1.0")
5+
assert metadata.name == "PET-MAD v1.1.0"
6+
assert metadata.description == "A universal interatomic potential for advanced materials modeling"
7+
assert metadata.authors == ["Arslan Mazitov (arslan.mazitov@epfl.ch)", "Filippo Bigi", "Matthias Kellner", "Paolo Pegolo", "Davide Tisi", "Guillaume Fraux", "Sergey Pozdnyakov", "Philip Loche", "Michele Ceriotti (michele.ceriotti@epfl.ch)"]
8+
assert metadata.references == {"architecture": ["https://arxiv.org/abs/2305.19302v3"], "model": ["http://arxiv.org/abs/2503.14118"]}

0 commit comments

Comments
 (0)