Skip to content

Commit 4180f3f

Browse files
committed
make sure all plms are excluded from state_dict
1 parent 12cd5c1 commit 4180f3f

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@
8282

8383
from alphafold3_pytorch.plm import (
8484
PLMEmbedding,
85-
PLMRegistry
85+
PLMRegistry,
86+
remove_plms
8687
)
8788

8889
from alphafold3_pytorch.utils.model_utils import (
@@ -6242,6 +6243,14 @@ def __init__(
62426243
def device(self):
62436244
return self.zero.device
62446245

6246+
@remove_plms
6247+
def state_dict(self, *args, **kwargs):
6248+
return super().state_dict(*args, **kwargs)
6249+
6250+
@remove_plms
6251+
def load_state_dict(self, *args, **kwargs):
6252+
return super().load_state_dict(*args, **kwargs)
6253+
62456254
@property
62466255
def state_dict_with_init_args(self):
62476256
return dict(

alphafold3_pytorch/plm.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from functools import partial
2+
from functools import partial, wraps
33

44
import torch
55
from beartype.typing import Literal
@@ -15,6 +15,22 @@
1515
def join(arr, delimiter = ''): # just redo an ugly part of python
1616
return delimiter.join(arr)
1717

18+
def remove_plms(fn):
19+
@wraps(fn)
20+
def inner(self, *args, **kwargs):
21+
has_plms = hasattr(self, 'plms')
22+
if has_plms:
23+
plms = self.plms
24+
delattr(self, 'plms')
25+
26+
out = fn(self, *args, **kwargs)
27+
28+
if has_plms:
29+
self.plms = plms
30+
31+
return out
32+
return inner
33+
1834
# constants
1935

2036
aa_constants = get_residue_constants(res_chem_index=IS_PROTEIN)

tests/test_af3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,10 @@ def test_alphafold3_with_plm_embeddings():
10981098
plm_embeddings="esm2_t33_650M_UR50D",
10991099
)
11001100

1101+
state_dict = alphafold3.state_dict()
1102+
1103+
assert not any([key.startswith('plms') for key in state_dict.keys()])
1104+
11011105
# mock inputs
11021106

11031107
seq_len = 16

0 commit comments

Comments
 (0)