Skip to content

Commit 5af20c8

Browse files
committed
set weight_only=False for esm reader
1 parent 196d662 commit 5af20c8

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

chebai_proteins/preprocessing/reader.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,29 @@ def load_model_and_alphabet(self) -> Tuple[ESM2, Alphabet]:
221221
"""
222222
model_location = os.path.join(self.save_model_dir, f"{self.model_name}.pt")
223223
if os.path.exists(model_location):
224-
return load_model_and_alphabet_local(model_location)
224+
return self.load_model_and_alphabet_local(model_location)
225225
else:
226226
return self.load_model_and_alphabet_hub()
227227

228+
@staticmethod
229+
def load_model_and_alphabet_local(model_location):
230+
"""Load from local path. The regression weights need to be co-located"""
231+
model_location = Path(model_location)
232+
model_data = torch.load(
233+
str(model_location), map_location="cpu", weights_only=False
234+
)
235+
model_name = model_location.stem
236+
if _has_regression_weights(model_name):
237+
regression_location = (
238+
str(model_location.with_suffix("")) + "-contact-regression.pt"
239+
)
240+
regression_data = torch.load(
241+
regression_location, map_location="cpu", weights_only=False
242+
)
243+
else:
244+
regression_data = None
245+
return load_model_and_alphabet_core(model_name, model_data, regression_data)
246+
228247
def load_model_and_alphabet_hub(self) -> Tuple[ESM2, Alphabet]:
229248
"""
230249
Load the model and alphabet from the hub URL.
@@ -257,7 +276,11 @@ def load_hub_workaround(self, url) -> torch.Tensor:
257276
"""
258277
try:
259278
data = torch.hub.load_state_dict_from_url(
260-
url, self.save_model_dir, progress=True, map_location=self.device
279+
url,
280+
self.save_model_dir,
281+
progress=True,
282+
map_location=self.device,
283+
weights_only=False,
261284
)
262285

263286
except RuntimeError:
@@ -266,6 +289,7 @@ def load_hub_workaround(self, url) -> torch.Tensor:
266289
data = torch.load(
267290
f"{torch.hub.get_dir()}/checkpoints/{fn}",
268291
map_location="cpu",
292+
weights_only=False,
269293
)
270294
except HTTPError as e:
271295
raise Exception(

0 commit comments

Comments
 (0)