11import importlib .util
22import logging
33import warnings
4+ from typing import Optional
45
5- from metatomic .torch import AtomisticModel , ModelMetadata
6+ from metatomic .torch import AtomisticModel
67from metatrain .utils .io import load_model as load_metatrain_model
8+ from .utils import get_metadata
79
10+ from packaging .version import Version
811
9- METADATA = ModelMetadata (
10- name = "PET-MAD" ,
11- description = "A universal interatomic potential for advanced materials modeling" ,
12- authors = [
13- "Arslan Mazitov (arslan.mazitov@epfl.ch)" ,
14- "Filippo Bigi" ,
15- "Matthias Kellner" ,
16- "Paolo Pegolo" ,
17- "Davide Tisi" ,
18- "Guillaume Fraux" ,
19- "Sergey Pozdnyakov" ,
20- "Philip Loche" ,
21- "Michele Ceriotti (michele.ceriotti@epfl.ch)" ,
22- ],
23- references = {
24- "architecture" : ["https://arxiv.org/abs/2305.19302v3" ],
25- "model" : ["http://arxiv.org/abs/2503.14118" ],
26- },
27- )
28- VERSIONS = ("latest" , "1.1.0" , "1.0.1" , "1.0.0" )
29- BASE_URL = (
30- "https://huggingface.co/lab-cosmo/pet-mad/resolve/{}/models/pet-mad-latest.ckpt"
31- )
32-
33-
34- def get_pet_mad (* , version = "latest" , checkpoint_path = None ) -> AtomisticModel :
12+ from ._version import LATEST_VERSION , AVAILABLE_VERSIONS
13+
14+ BASE_URL = "https://huggingface.co/lab-cosmo/pet-mad/resolve/{tag}/models/pet-mad-{version}.ckpt"
15+
16+
17+ def get_pet_mad (
18+ * , version : str = "latest" , checkpoint_path : Optional [str ] = None
19+ ) -> AtomisticModel :
3520 """Get a metatomic ``AtomisticModel`` for PET-MAD.
3621
37- :param version: PET-MAD version to use. Supported versions are "latest", "1.1.0",
38- "1.0. 1", "1.0.0". Defaults to " latest" .
22+ :param version: PET-MAD version to use. Supported versions are
23+ "1.1.0", "1.0. 1", "1.0.0". Defaults to latest available version .
3924 :param checkpoint_path: path to a checkpoint file to load the model from. If
4025 provided, the `version` parameter is ignored.
4126 """
42- if version not in VERSIONS :
27+ if version == "latest" :
28+ version = Version (LATEST_VERSION )
29+ if not isinstance (version , Version ):
30+ version = Version (version )
31+
32+ if version not in [Version (v ) for v in AVAILABLE_VERSIONS ]:
4333 raise ValueError (
44- f"Version { version } is not supported. Supported versions are { VERSIONS } "
34+ f"Version { version } is not supported. Supported versions are { AVAILABLE_VERSIONS } "
4535 )
4636
47- if version == "1.0.0" :
37+ if version == Version ( "1.0.0" ) :
4838 if not importlib .util .find_spec ("pet_neighbors_convert" ):
4939 raise ImportError (
5040 f"PET-MAD v{ version } is now deprecated. Please consider using the "
@@ -60,9 +50,7 @@ def get_pet_mad(*, version="latest", checkpoint_path=None) -> AtomisticModel:
6050 path = checkpoint_path
6151 else :
6252 logging .info (f"Downloading PET-MAD model version: { version } " )
63- path = BASE_URL .format (
64- f"v{ version } " if version not in ("latest" , "1.1.0" ) else "main"
65- )
53+ path = BASE_URL .format (tag = f"v{ version } " , version = f"v{ version } " )
6654
6755 with warnings .catch_warnings ():
6856 warnings .filterwarnings (
@@ -71,25 +59,31 @@ def get_pet_mad(*, version="latest", checkpoint_path=None) -> AtomisticModel:
7159 )
7260 model = load_metatrain_model (path )
7361
74- return model .export (METADATA )
62+ metadata = get_metadata (version )
63+ return model .export (metadata )
7564
7665
77- def save_pet_mad (* , version = "latest" , checkpoint_path = None , output = None ):
66+ def save_pet_mad (* , version : str = "latest" , checkpoint_path = None , output = None ):
7867 """
7968 Save the PET-MAD model to a TorchScript file (``pet-mad-xxx.pt``). These files can
8069 be used with LAMMPS and other tools to run simulations without Python.
8170
82- :param version: PET-MAD version to use. Supported versions are "latest", " 1.1.0",
83- "1.0.1", "1.0.0". Defaults to " latest" .
71+ :param version: PET-MAD version to use. Supported versions are "1.1.0",
72+ "1.0.1", "1.0.0". Defaults to the latest version .
8473 :param checkpoint_path: path to a checkpoint file to load the model from. If
8574 provided, the `version` parameter is ignored.
8675 :param output: path to use for the output model, defaults to
8776 ``pet-mad-{version}.pt`` when using a version, or the checkpoint path when using
8877 a checkpoint.
8978 """
79+ if version == "latest" :
80+ version = Version (LATEST_VERSION )
81+ if not isinstance (version , Version ):
82+ version = Version (version )
83+
9084 extensions_directory = None
91- if version == "1.0.0" :
92- logging .info ("putting TorchScript extensions in `extensions/`" )
85+ if version == Version ( "1.0.0" ) :
86+ logging .info ("Putting TorchScript extensions in `extensions/`" )
9387 extensions_directory = "extensions"
9488
9589 model = get_pet_mad (version = version , checkpoint_path = checkpoint_path )
@@ -101,4 +95,4 @@ def save_pet_mad(*, version="latest", checkpoint_path=None, output=None):
10195 raise
10296
10397 model .save (output , collect_extensions = extensions_directory )
104- logging .info (f"saved pet-mad model to { output } " )
98+ logging .info (f"Saved PET-MAD model to { output } " )
0 commit comments