@@ -221,10 +221,29 @@ def load_model_and_alphabet(self) -> Tuple[ESM2, Alphabet]:
221221 """
222222 model_location = os .path .join (self .save_model_dir , f"{ self .model_name } .pt" )
223223 if os .path .exists (model_location ):
224- return load_model_and_alphabet_local (model_location )
224+ return self . load_model_and_alphabet_local (model_location )
225225 else :
226226 return self .load_model_and_alphabet_hub ()
227227
228+ @staticmethod
229+ def load_model_and_alphabet_local (model_location ):
230+ """Load from local path. The regression weights need to be co-located"""
231+ model_location = Path (model_location )
232+ model_data = torch .load (
233+ str (model_location ), map_location = "cpu" , weights_only = False
234+ )
235+ model_name = model_location .stem
236+ if _has_regression_weights (model_name ):
237+ regression_location = (
238+ str (model_location .with_suffix ("" )) + "-contact-regression.pt"
239+ )
240+ regression_data = torch .load (
241+ regression_location , map_location = "cpu" , weights_only = False
242+ )
243+ else :
244+ regression_data = None
245+ return load_model_and_alphabet_core (model_name , model_data , regression_data )
246+
228247 def load_model_and_alphabet_hub (self ) -> Tuple [ESM2 , Alphabet ]:
229248 """
230249 Load the model and alphabet from the hub URL.
@@ -257,7 +276,11 @@ def load_hub_workaround(self, url) -> torch.Tensor:
257276 """
258277 try :
259278 data = torch .hub .load_state_dict_from_url (
260- url , self .save_model_dir , progress = True , map_location = self .device
279+ url ,
280+ self .save_model_dir ,
281+ progress = True ,
282+ map_location = self .device ,
283+ weights_only = False ,
261284 )
262285
263286 except RuntimeError :
@@ -266,6 +289,7 @@ def load_hub_workaround(self, url) -> torch.Tensor:
266289 data = torch .load (
267290 f"{ torch .hub .get_dir ()} /checkpoints/{ fn } " ,
268291 map_location = "cpu" ,
292+ weights_only = False ,
269293 )
270294 except HTTPError as e :
271295 raise Exception (
0 commit comments