@@ -109,7 +109,11 @@ def upet_get_version_to_load(
109109
110110
111111def get_upet (
112- * , model : str , size : str , version : str , checkpoint_path : Optional [str ] = None
112+ * ,
113+ model : str ,
114+ size : str ,
115+ version : str = "latest" ,
116+ checkpoint_path : Optional [str ] = None ,
113117) -> AtomisticModel :
114118 """Get a metatomic ``AtomisticModel`` for a UPET MLIP.
115119
@@ -119,6 +123,10 @@ def get_upet(
119123 :param checkpoint_path: path to a checkpoint file to load the model from. If
120124 provided, the `version` parameter is ignored.
121125 """
126+ if version == "latest" :
127+ version = upet_get_version_to_load (model , size , requested_version = version )
128+ if not isinstance (version , Version ):
129+ version = Version (version )
122130 if checkpoint_path is not None :
123131 logging .info (f"Loading model from checkpoint: { checkpoint_path } " )
124132 path = checkpoint_path
@@ -146,7 +154,7 @@ def save_upet(
146154 * ,
147155 model : str ,
148156 size : str ,
149- version : str ,
157+ version : str = "latest" ,
150158 checkpoint_path : Optional [str ] = None ,
151159 output = None ,
152160):
@@ -172,7 +180,7 @@ def save_upet(
172180 if checkpoint_path is None :
173181 output = "-" .join ([model , size , f"v{ version } " ]) + ".pt"
174182 else :
175- raise
183+ raise ValueError ( "Output path must be specified when using a checkpoint." )
176184
177185 loaded_model .save (output )
178186 logging .info (f"Saved UPET model to { output } " )
0 commit comments