Skip to content

Commit d5d3b96

Browse files
Add instructions for fine-tuning (#26)
1 parent 744202d commit d5d3b96

File tree

4 files changed

+53
-0
lines changed

4 files changed

+53
-0
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ Models
105105
See [here](docs/models.md) for the complete list of the models we provide. If you are
106106
new to FlashMD, we recommend starting with the ``pet-omatpes-v2`` models.
107107

108+
Training/fine-tuning your own FlashMD models
109+
--------------------------------------------
110+
111+
FlashMD models can be trained from the metatrain library. This
112+
[tutorial](https://docs.metatensor.org/metatrain/latest/generated_examples/1-advanced/04-flashmd.html)
113+
shows how to train your own FlashMD model, either from scratch or via fine-tuning of one of our universal models.
114+
108115
Disclaimer
109116
----------
110117

src/flashmd/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22

33
from .models import get_pretrained as get_pretrained
4+
from .models import save_checkpoint as save_checkpoint
45

56

67
warnings.filterwarnings("ignore", category=UserWarning, message="custom data")

src/flashmd/models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import shutil
23
import subprocess
34
import time
45

@@ -81,3 +82,21 @@ def get_pretrained(mlip: str = "pet-omatpes-v2", time_step: int = 16) -> Atomist
8182
flashmd_model = load_atomistic_model(exported_flashmd_path)
8283

8384
return mlip_model, flashmd_model
85+
86+
87+
def save_checkpoint(mlip: str = "pet-omatpes-v2", time_step: int = 16):
88+
if time_step not in AVAILABLE_TIME_STEPS[mlip]:
89+
raise ValueError(
90+
f"Pre-trained FlashMD models based on the {mlip} MLIP are only available "
91+
f"for time steps of {', '.join(map(str, AVAILABLE_TIME_STEPS[mlip]))} fs."
92+
)
93+
94+
checkpoint_path = hf_hub_download(
95+
repo_id="lab-cosmo/flashmd",
96+
filename=f"flashmd_{mlip}_{time_step}fs.ckpt",
97+
cache_dir=None,
98+
revision="main",
99+
)
100+
101+
# Copy it to the current directory
102+
shutil.copyfile(checkpoint_path, f"flashmd_{mlip}_{time_step}fs.ckpt")

tests/test_models.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,29 @@ def test_get_pretrained_invalid_time_step():
3737

3838
with pytest.raises(ValueError, match="Pre-trained FlashMD models"):
3939
get_pretrained(mlip="pet-omatpes", time_step=999)
40+
41+
42+
def test_save_checkpoint_invalid_time_step():
43+
"""Test that save_checkpoint raises ValueError for invalid time step."""
44+
from flashmd.models import save_checkpoint
45+
46+
with pytest.raises(ValueError, match="Pre-trained FlashMD models"):
47+
save_checkpoint(mlip="pet-omatpes-v2", time_step=999)
48+
49+
50+
def test_save_checkpoint(monkeypatch):
51+
"""Test that save_checkpoint saves the checkpoint file."""
52+
from flashmd.models import save_checkpoint
53+
54+
# Mock hf_hub_download and shutil.copyfile
55+
def mock_hf_hub_download(repo_id, filename, cache_dir, revision):
56+
return f"/path/to/{filename}"
57+
58+
def mock_copyfile(src, dst):
59+
assert src == "/path/to/flashmd_pet-omatpes-v2_16fs.ckpt"
60+
assert dst == "flashmd_pet-omatpes-v2_16fs.ckpt"
61+
62+
monkeypatch.setattr("flashmd.models.hf_hub_download", mock_hf_hub_download)
63+
monkeypatch.setattr("flashmd.models.shutil.copyfile", mock_copyfile)
64+
65+
save_checkpoint(mlip="pet-omatpes-v2", time_step=16)

0 commit comments

Comments
 (0)