Skip to content

Commit 68c0e41

Browse files
authored
Updated the get-save upet functions (#77)
1 parent 4ea73b6 commit 68c0e41

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

src/upet/_models.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,11 @@ def upet_get_version_to_load(
109109

110110

111111
def get_upet(
112-
*, model: str, size: str, version: str, checkpoint_path: Optional[str] = None
112+
*,
113+
model: str,
114+
size: str,
115+
version: str = "latest",
116+
checkpoint_path: Optional[str] = None,
113117
) -> AtomisticModel:
114118
"""Get a metatomic ``AtomisticModel`` for a UPET MLIP.
115119
@@ -119,6 +123,10 @@ def get_upet(
119123
:param checkpoint_path: path to a checkpoint file to load the model from. If
120124
provided, the `version` parameter is ignored.
121125
"""
126+
if version == "latest":
127+
version = upet_get_version_to_load(model, size, requested_version=version)
128+
if not isinstance(version, Version):
129+
version = Version(version)
122130
if checkpoint_path is not None:
123131
logging.info(f"Loading model from checkpoint: {checkpoint_path}")
124132
path = checkpoint_path
@@ -146,7 +154,7 @@ def save_upet(
146154
*,
147155
model: str,
148156
size: str,
149-
version: str,
157+
version: str = "latest",
150158
checkpoint_path: Optional[str] = None,
151159
output=None,
152160
):
@@ -172,7 +180,7 @@ def save_upet(
172180
if checkpoint_path is None:
173181
output = "-".join([model, size, f"v{version}"]) + ".pt"
174182
else:
175-
raise
183+
raise ValueError("Output path must be specified when using a checkpoint.")
176184

177185
loaded_model.save(output)
178186
logging.info(f"Saved UPET model to {output}")

src/upet/calculator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def __init__(
113113
)
114114
model, size = model.rsplit("-", 1)
115115
size = upet_get_size_to_load(model, requested_size=size)
116-
version = upet_get_version_to_load(model, size, requested_version=version)
116+
if version == "latest":
117+
version = upet_get_version_to_load(model, size, requested_version=version)
117118

118119
if not isinstance(version, Version):
119120
version = Version(version)

0 commit comments

Comments
 (0)