Skip to content

Commit 6f81b7a

Browse files
committed
checkpointable msa module
1 parent 70a39d8 commit 6f81b7a

File tree

3 files changed

+127
-19
lines changed

3 files changed

+127
-19
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 115 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,8 @@ def __init__(
981981
msa_pwa_dropout_row_prob = 0.15,
982982
msa_pwa_heads = 8,
983983
msa_pwa_dim_head = 32,
984+
checkpoint = False,
985+
checkpoint_segments = 1,
984986
pairwise_block_kwargs: dict = dict(),
985987
max_num_msa: int | None = None,
986988
layerscale_output: bool = True
@@ -1028,10 +1030,105 @@ def __init__(
10281030
pairwise_block
10291031
]))
10301032

1033+
self.checkpoint = checkpoint
1034+
self.checkpoint_segments = checkpoint_segments
1035+
10311036
self.layers = layers
10321037

10331038
self.layerscale_output = nn.Parameter(torch.zeros(dim_pairwise)) if layerscale_output else 1.
10341039

1040+
@typecheck
1041+
def to_layers(
1042+
self,
1043+
*,
1044+
pairwise_repr: Float['b n n dp'],
1045+
msa: Float['b s n dm'],
1046+
mask: Bool['b n'] | None = None,
1047+
msa_mask: Bool['b s'] | None = None,
1048+
) -> Float['b n n dp']:
1049+
1050+
for (
1051+
outer_product_mean,
1052+
msa_pair_weighted_avg,
1053+
msa_transition,
1054+
pairwise_block
1055+
) in self.layers:
1056+
1057+
# communication between msa and pairwise rep
1058+
1059+
pairwise_repr = outer_product_mean(msa, mask = mask, msa_mask = msa_mask) + pairwise_repr
1060+
1061+
msa = msa_pair_weighted_avg(msa = msa, pairwise_repr = pairwise_repr, mask = mask) + msa
1062+
msa = msa_transition(msa) + msa
1063+
1064+
# pairwise block
1065+
1066+
pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
1067+
1068+
return pairwise_repr
1069+
1070+
@typecheck
1071+
def to_checkpointed_layers(
1072+
self,
1073+
*,
1074+
pairwise_repr: Float['b n n dp'],
1075+
msa: Float['b s n dm'],
1076+
mask: Bool['b n'] | None = None,
1077+
msa_mask: Bool['b s'] | None = None,
1078+
) -> Float['b n n dp']:
1079+
1080+
inputs = (pairwise_repr, mask, msa, msa_mask)
1081+
1082+
wrapped_layers = []
1083+
1084+
def outer_product_mean_wrapper(fn):
1085+
@wraps(fn)
1086+
def inner(inputs):
1087+
pairwise_repr, mask, msa, msa_mask = inputs
1088+
pairwise_repr = fn(msa = msa, mask = mask, msa_mask = msa_mask) + pairwise_repr
1089+
return pairwise_repr, mask, msa, msa_mask
1090+
return inner
1091+
1092+
def msa_pair_weighted_avg_wrapper(fn):
1093+
@wraps(fn)
1094+
def inner(inputs):
1095+
pairwise_repr, mask, msa, msa_mask = inputs
1096+
msa = fn(msa = msa, pairwise_repr = pairwise_repr, mask = mask) + msa
1097+
return pairwise_repr, mask, msa, msa_mask
1098+
return inner
1099+
1100+
def pairwise_block_wrapper(fn):
1101+
@wraps(fn)
1102+
def inner(inputs):
1103+
pairwise_repr, mask, msa, msa_mask = inputs
1104+
pairwise_repr = fn(pairwise_repr = pairwise_repr, mask = mask)
1105+
return pairwise_repr, mask, msa, msa_mask
1106+
return inner
1107+
1108+
def msa_transition_wrapper(fn):
1109+
@wraps(fn)
1110+
def inner(inputs):
1111+
pairwise_repr, mask, msa, msa_mask = inputs
1112+
msa = fn(msa) + msa
1113+
return pairwise_repr, mask, msa, msa_mask
1114+
return inner
1115+
1116+
for (
1117+
outer_product_mean,
1118+
msa_pair_weighted_avg,
1119+
msa_transition,
1120+
pairwise_block
1121+
) in self.layers:
1122+
1123+
wrapped_layers.append(outer_product_mean_wrapper(outer_product_mean))
1124+
wrapped_layers.append(msa_pair_weighted_avg_wrapper(msa_pair_weighted_avg))
1125+
wrapped_layers.append(msa_transition_wrapper(msa_transition))
1126+
wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
1127+
1128+
pairwise_repr, *_ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False)
1129+
1130+
return pairwise_repr
1131+
10351132
@typecheck
10361133
def forward(
10371134
self,
@@ -1073,23 +1170,21 @@ def forward(
10731170

10741171
msa = rearrange(single_msa_feats, 'b n d -> b 1 n d') + msa
10751172

1076-
for (
1077-
outer_product_mean,
1078-
msa_pair_weighted_avg,
1079-
msa_transition,
1080-
pairwise_block
1081-
) in self.layers:
1082-
1083-
# communication between msa and pairwise rep
1084-
1085-
pairwise_repr = outer_product_mean(msa, mask = mask, msa_mask = msa_mask) + pairwise_repr
1173+
# going through the layers
10861174

1087-
msa = msa_pair_weighted_avg(msa = msa, pairwise_repr = pairwise_repr, mask = mask) + msa
1088-
msa = msa_transition(msa) + msa
1175+
if should_checkpoint(self, (pairwise_repr, msa)):
1176+
to_layers_fn = self.to_checkpointed_layers
1177+
else:
1178+
to_layers_fn = self.to_layers
10891179

1090-
# pairwise block
1180+
pairwise_repr = to_layers_fn(
1181+
msa = msa,
1182+
mask = mask,
1183+
pairwise_repr = pairwise_repr,
1184+
msa_mask = msa_mask
1185+
)
10911186

1092-
pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
1187+
# final masking and then layer scale
10931188

10941189
if exists(msa_mask):
10951190
pairwise_repr = einx.where(
@@ -1208,20 +1303,23 @@ def to_checkpointed_layers(
12081303
inputs = (single_repr, pairwise_repr, mask)
12091304

12101305
def pairwise_block_wrapper(layer):
1306+
@wraps(layer)
12111307
def inner(inputs, *args, **kwargs):
12121308
single_repr, pairwise_repr, mask = inputs
12131309
pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
12141310
return single_repr, pairwise_repr, mask
12151311
return inner
12161312

12171313
def pair_bias_attn_wrapper(layer):
1314+
@wraps(layer)
12181315
def inner(inputs, *args, **kwargs):
12191316
single_repr, pairwise_repr, mask = inputs
12201317
single_repr = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
12211318
return single_repr, pairwise_repr, mask
12221319
return inner
12231320

12241321
def single_transition_wrapper(layer):
1322+
@wraps(layer)
12251323
def inner(inputs, *args, **kwargs):
12261324
single_repr, pairwise_repr, mask = inputs
12271325
single_repr = layer(single_repr) + single_repr
@@ -1725,20 +1823,23 @@ def to_checkpointed_serial_layers(
17251823
wrapped_layers = []
17261824

17271825
def efficient_attn_wrapper(fn):
1826+
@wraps(fn)
17281827
def inner(inputs):
17291828
noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
17301829
noised_repr = fn(noised_repr, mask = mask) + noised_repr
17311830
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
17321831
return inner
17331832

17341833
def attn_wrapper(fn):
1834+
@wraps(fn)
17351835
def inner(inputs):
17361836
noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
17371837
noised_repr = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask) + noised_repr
17381838
return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
17391839
return inner
17401840

17411841
def transition_wrapper(fn):
1842+
@wraps(fn)
17421843
def inner(inputs):
17431844
noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
17441845
noised_repr = fn(noised_repr, cond = single_repr) + noised_repr

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.64"
3+
version = "0.2.65"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,17 @@ def test_pairformer(
208208
loss = single_out.sum() + pairwise_out.sum()
209209
loss.backward()
210210

211-
def test_msa_module():
212-
213-
single = torch.randn(2, 16, 384)
214-
pairwise = torch.randn(2, 16, 16, 128)
211+
@pytest.mark.parametrize('checkpoint', (False, True))
212+
def test_msa_module(
213+
checkpoint
214+
):
215+
single = torch.randn(2, 16, 384).requires_grad_()
216+
pairwise = torch.randn(2, 16, 16, 128).requires_grad_()
215217
msa = torch.randn(2, 7, 16, 64)
216218
mask = torch.randint(0, 2, (2, 16)).bool()
217219

218220
msa_module = MSAModule(
221+
checkpoint = checkpoint,
219222
max_num_msa = 3 # will randomly select 3 out of the MSAs, accounting for mask, using sample without replacement
220223
)
221224

@@ -228,6 +231,10 @@ def test_msa_module():
228231

229232
assert pairwise.shape == pairwise_out.shape
230233

234+
if checkpoint:
235+
loss = pairwise_out.sum()
236+
loss.backward()
237+
231238
@pytest.mark.parametrize('serial,checkpoint', ((False, False), (True, False), (True, True)))
232239
@pytest.mark.parametrize('use_linear_attn', (False, True))
233240
@pytest.mark.parametrize('use_colt5_attn', (False, True))

0 commit comments

Comments (0)