Skip to content

Commit 1907779

Browse files
ceriottmabmazitov
andauthored
Add list_upet() function (#118)
--------- Co-authored-by: Arslan Mazitov <arslan.mazitov@phystech.edu>
1 parent 41fac54 commit 1907779

File tree

6 files changed

+141
-6
lines changed

6 files changed

+141
-6
lines changed

README.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ UPET integrates with the following atomistic simulation engines:
130130

131131
- **Atomic Simulation Environment (ASE)**
132132
- **LAMMPS** (including KOKKOS support)
133+
- **GROMACS**
133134
- **i-PI**
134135
- **TorchSim**
135136
- **OpenMM** (coming soon)
136-
- **GROMACS** (coming soon)
137137

138138
## Usage
139139

@@ -175,6 +175,26 @@ These ASE methods are ideal for single-structure evaluations, but they are
175175
inefficient for the evaluation on a large number of pre-defined structures. To
176176
perform efficient batched evaluation in that case, read [here](docs/README_BATCHED.md).
177177

178+
If the `version` argument is not specified, the latest available version of the model
179+
will be downloaded and used by default.
180+
181+
```python
182+
from upet.calculator import UPETCalculator
183+
184+
calculator = UPETCalculator(model="pet-mad-s", device="cpu") # uses the latest version of the PET-MAD-S model by default
185+
```
186+
187+
You can get the list of available versions for a given model using the
188+
`list_upet` function:
189+
190+
```python
191+
from upet import list_upet
192+
193+
list_upet(model="pet-mad", size="s") # for PET-MAD model of size S
194+
list_upet(model="pet-mad") # for all PET-MAD models of all sizes
195+
list_upet() # for all available UPET models
196+
```
197+
178198
#### Non-conservative (direct) forces and stresses prediction
179199

180200
UPET models also support the direct prediction of forces and stresses. In that case,

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,9 @@ filterwarnings = [
106106
# metatrain checkpoint upgrade warning
107107
"ignore: trying to upgrade an old model checkpoint with unknown version, this might fail and require manual modifications",
108108
# TorchScript deprecation warnings
109-
"ignore: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.",
110-
"ignore: `torch.jit.load` is deprecated. Please switch to `torch.export`.",
109+
"ignore: `torch.jit.script` is deprecated.*:DeprecationWarning",
110+
"ignore: `torch.jit.save` is deprecated.*:DeprecationWarning",
111+
"ignore: `torch.jit.load` is deprecated.*:DeprecationWarning",
111112
"ignore: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.",
112113
# PET-MAD v1.0 deprecation warning
113114
"ignore:.*is deprecated in favor of the newer PET-MAD-1.5.*:DeprecationWarning",

src/upet/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from warp._src import utils as wp_utils
77

8-
from ._models import get_upet, save_upet
8+
from ._models import get_upet, list_upet, save_upet
99

1010

1111
# hides a harmless warning from nvalchemi's neighbor list implmentation
@@ -33,4 +33,4 @@ def _warn_filtered(message, category=None, stacklevel=1):
3333
# causing "Global alloc not supported yet" errors (cuda 13+) at the time of writing
3434
torch.jit.set_fusion_strategy([("DYNAMIC", 10)])
3535

36-
__all__ = ["get_upet", "save_upet"]
36+
__all__ = ["get_upet", "list_upet", "save_upet"]

src/upet/_models.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ def _get_upet_repo_files() -> List[str]:
3636
return [f[7:] for f in repo_files if f.startswith("models/")]
3737

3838

39+
def get_available_models() -> List[str]:
40+
"""Get all available base model names from the HuggingFace repository.
41+
42+
:return: Sorted list of base model names (e.g., ["pet-mad", "pet-omat", ...])
43+
"""
44+
files = _get_upet_repo_files()
45+
models = set()
46+
for f in files:
47+
match = CHECKPOINT_NAME_PATTERN.match(f)
48+
if match:
49+
models.add(match.group("model"))
50+
return sorted(models)
51+
52+
3953
def get_sizes_for_model(model: str) -> List[str]:
4054
"""Get all available sizes for a given model from the cached repo files.
4155
@@ -278,6 +292,53 @@ def save_upet(
278292
logging.info(f"Saved UPET model to {output}")
279293

280294

295+
def list_upet(
296+
*,
297+
model: Optional[str] = None,
298+
size: Optional[str] = None,
299+
print_summary: bool = True,
300+
) -> List[dict]:
301+
"""List available UPET models, sizes, and versions.
302+
303+
When called without arguments, returns all available model/size/version
304+
combinations. When ``model`` is given, filters to that model. When both
305+
``model`` and ``size`` are given, filters to that specific combination.
306+
307+
:param model: Base model name (e.g., "pet-mad", "pet-omat"). If ``None``,
308+
lists all available models.
309+
:param size: Model size (e.g., "s", "m", "l"). If ``None`` and ``model`` is
310+
given, lists all sizes for that model.
311+
:param print_summary: Whether to print a human-readable summary to stdout.
312+
Defaults to ``True``.
313+
:return: A list of dictionaries, each with keys ``"model"``, ``"size"``,
314+
and ``"version"``.
315+
"""
316+
if model is None:
317+
models = get_available_models()
318+
else:
319+
models = [model]
320+
321+
result = []
322+
for m in models:
323+
if size is None:
324+
sizes = get_sizes_for_model(m)
325+
else:
326+
sizes = [size]
327+
for s in sizes:
328+
for v in get_versions_for_model(m, s):
329+
result.append({"model": m, "size": s, "version": str(v)})
330+
331+
if print_summary:
332+
if not result:
333+
print("No UPET models found.")
334+
else:
335+
print("Available UPET models:")
336+
for entry in result:
337+
print(f" - {entry['model']}-{entry['size']} v{entry['version']}")
338+
339+
return result
340+
341+
281342
BASE_URL_PET_MAD_DOS = "https://huggingface.co/lab-cosmo/pet-mad-dos/resolve/{tag}/models/pet-mad-dos-{version}.pt"
282343
BASE_URL_BANDGAP_MODEL = (
283344
"https://huggingface.co/lab-cosmo/pet-mad-dos/resolve/{tag}/models/bandgap-model.pt"

tests/upet/test_basic_usage.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
import os
2+
13
import pytest
24
from ase.build import bulk, molecule
35

4-
from upet._models import get_upet, get_versions_for_model, upet_resolve_model
6+
from upet._models import (
7+
get_upet,
8+
get_versions_for_model,
9+
list_upet,
10+
save_upet,
11+
upet_resolve_model,
12+
)
513
from upet._version import UPET_AVAILABLE_MODELS
614
from upet.calculator import UPETCalculator
715

@@ -50,6 +58,27 @@ def test_get_upet(model_name):
5058
get_upet(model=model, size=size, version=version)
5159

5260

61+
def test_list_models():
62+
result = list_upet(print_summary=False)
63+
assert len(result) > 0
64+
assert all(
65+
"model" in entry and "size" in entry and "version" in entry for entry in result
66+
)
67+
assert any(entry["model"] == "pet-mad" for entry in result)
68+
69+
70+
def test_list_sizes_for_model():
71+
result = list_upet(model="pet-mad", print_summary=False)
72+
assert len(result) > 0
73+
assert all(entry["model"] == "pet-mad" for entry in result)
74+
75+
76+
def test_list_versions_for_model_and_size():
77+
result = list_upet(model="pet-mad", size="s", print_summary=False)
78+
assert len(result) > 0
79+
assert all(entry["model"] == "pet-mad" and entry["size"] == "s" for entry in result)
80+
81+
5382
@pytest.mark.parametrize("model_name", UPET_AVAILABLE_MODELS)
5483
def test_basic_usage(model_name):
5584
if "-xl" in model_name or "-l" in model_name:
@@ -72,3 +101,10 @@ def test_basic_usage(model_name):
72101
assert isinstance(energy, float)
73102
assert forces.shape == (len(atoms), 3)
74103
assert virial.shape == (6,)
104+
105+
106+
def test_save_upet(tmp_path):
107+
output_path = str(tmp_path / "pet-mad-xs.pt")
108+
save_upet(model="pet-mad", size="xs", output=output_path)
109+
assert os.path.isfile(output_path)
110+
assert os.path.getsize(output_path) > 0
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from packaging.version import Version
2+
3+
from upet._models import parse_checkpoint_filename
4+
5+
6+
def test_parse_standard_checkpoint_name():
7+
model, size, version = parse_checkpoint_filename("pet-mad-s-v1.2.3.ckpt")
8+
assert model == "pet-mad"
9+
assert size == "s"
10+
assert version == Version("1.2.3")
11+
12+
13+
def test_parse_non_standard_checkpoint_name():
14+
model, size, version = parse_checkpoint_filename("custom_name.ckpt")
15+
assert model is None
16+
assert size is None
17+
assert version is None

0 commit comments

Comments
 (0)