
Commit 7d1025a

add two more checkpoints on trunk pairformer and diffusion token transformer
1 parent af6d560

3 files changed (+36, -6 lines)

alphafold3_pytorch/alphafold3.py

Lines changed: 33 additions & 5 deletions
@@ -2130,6 +2130,7 @@ def __init__(
         atom_decoder_kwargs: dict = dict(),
         token_transformer_kwargs: dict = dict(),
         use_linear_attn = False,
+        checkpoint_token_transformer = False,
         linear_attn_kwargs: dict = dict(
             heads = 8,
             dim_head = 16
@@ -2222,6 +2223,10 @@ def __init__(
 
         self.attended_token_norm = nn.LayerNorm(dim_token)
 
+        # checkpointing
+
+        self.checkpoint_token_transformer = checkpoint_token_transformer
+
         # atom attention decoding related modules
 
         self.tokens_to_atom_decoder_input_cond = LinearNoBias(dim_token, dim_atom)
@@ -2378,11 +2383,18 @@ def forward(
             molecule_atom_lens = molecule_atom_lens
         )
 
+        # maybe checkpoint token transformer
+
+        token_transformer = self.token_transformer
+
+        if should_checkpoint(self, tokens, 'checkpoint_token_transformer'):
+            token_transformer = partial(checkpoint, token_transformer, use_reentrant = False)
+
         # token transformer
 
         tokens = self.cond_tokens_with_cond_single(conditioned_single_repr) + tokens
 
-        tokens = self.token_transformer(
+        tokens = token_transformer(
             tokens,
             mask = mask,
             single_repr = conditioned_single_repr,
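The hunk above only swaps the submodule for a checkpointed callable when a flag allows it, trading recomputation in the backward pass for activation memory. Below is a minimal self-contained sketch of the same pattern; the body of should_checkpoint here is an assumption (checkpoint only in training, when the flag is set, and when an input carries gradients), not the repository's exact implementation:

from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

def should_checkpoint(module, inputs, flag_name):
    # hypothetical guard mirroring what the diff implies; the repository's
    # actual should_checkpoint may differ. `inputs` is one tensor or a tuple
    if torch.is_tensor(inputs):
        inputs = (inputs,)
    return (
        getattr(module, flag_name, False)
        and module.training
        and any(t.requires_grad for t in inputs)
    )

class TokenBlock(nn.Module):
    # stand-in for the diffusion token transformer, for illustration only
    def __init__(self, dim = 64, checkpoint_token_transformer = False):
        super().__init__()
        self.token_transformer = nn.TransformerEncoderLayer(dim, nhead = 4, batch_first = True)
        self.checkpoint_token_transformer = checkpoint_token_transformer

    def forward(self, tokens):
        token_transformer = self.token_transformer

        # when warranted, swap in a checkpointed wrapper with the same call signature
        if should_checkpoint(self, tokens, 'checkpoint_token_transformer'):
            token_transformer = partial(checkpoint, token_transformer, use_reentrant = False)

        return token_transformer(tokens)

block = TokenBlock(checkpoint_token_transformer = True).train()
x = torch.randn(1, 8, 64, requires_grad = True)
block(x).sum().backward()   # activations inside the layer are recomputed during backward

Since partial(checkpoint, fn, use_reentrant = False) accepts the same arguments as fn itself, the call site stays unchanged whether or not checkpointing is active.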
@@ -4300,7 +4312,10 @@ def __init__(
         ),
         augment_kwargs: dict = dict(),
         stochastic_frame_average = False,
-        confidence_head_atom_resolution = False
+        confidence_head_atom_resolution = False,
+        checkpoint_input_embedding = False,
+        checkpoint_trunk_pairformer = False,
+        checkpoint_diffusion_token_transformer = False,
     ):
         super().__init__()

@@ -4447,6 +4462,7 @@ def __init__(
             dim_atompair = dim_atompair,
             dim_token = dim_token,
             dim_single = dim_single + dim_single_inputs,
+            checkpoint_token_transformer = checkpoint_diffusion_token_transformer,
             **diffusion_module_kwargs
         )

@@ -4484,6 +4500,11 @@ def __init__(
             **confidence_head_kwargs
         )
 
+        # checkpointing related
+
+        self.checkpoint_trunk_pairformer = checkpoint_trunk_pairformer
+        self.checkpoint_diffusion_token_transformer = checkpoint_diffusion_token_transformer
+
         # loss related
 
         self.ignore_index = ignore_index
@@ -4817,9 +4838,16 @@ def forward(
 
         pairwise = embedded_msa + pairwise
 
+        # maybe checkpoint trunk pairformer
+
+        pairformer = self.pairformer
+
+        if should_checkpoint(self, (single, pairwise), 'checkpoint_trunk_pairformer'):
+            pairformer = partial(checkpoint, pairformer, use_reentrant = False)
+
         # main attention trunk (pairformer)
 
-        single, pairwise = self.pairformer(
+        single, pairwise = pairformer(
             single_repr = single,
             pairwise_repr = pairwise,
             mask = mask
@@ -4877,7 +4905,7 @@ def forward(
             pred_atom_pos = confidence_head_atom_pos_input.detach(),
             molecule_atom_indices = molecule_atom_indices,
             molecule_atom_lens = molecule_atom_lens,
-            atom_feats = atom_feats,
+            atom_feats = atom_feats.detach(),
             mask = mask,
             return_pae_logits = True
         )
@@ -5056,7 +5084,7 @@ def forward(
             molecule_atom_indices = molecule_atom_indices,
             molecule_atom_lens = molecule_atom_lens,
             mask = mask,
-            atom_feats = atom_feats,
+            atom_feats = atom_feats.detach(),
             return_pae_logits = return_pae_logits
         )
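Both confidence-head call sites now pass atom_feats.detach() alongside the already-detached predicted atom positions, presumably so that confidence losses do not backpropagate into the trunk features that produced them. A tiny illustration of what detaching does (hypothetical tensor, not the repository's code):

import torch

atom_feats = torch.randn(2, 16, requires_grad = True)   # stand-in for trunk features

# without detach, a loss computed on the features backpropagates into them
torch.sigmoid(atom_feats).sum().backward()
assert atom_feats.grad is not None

# .detach() returns a view cut out of the autograd graph, so a downstream
# loss (e.g. in the confidence head) cannot reach the original tensor
assert not atom_feats.detach().requires_grad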

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.70"
+version = "0.2.71"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 2 additions & 0 deletions
@@ -673,6 +673,8 @@ def test_alphafold3_without_msa_and_templates():
         dim_atom_inputs = 77,
         dim_template_feats = 44,
         num_dist_bins = 38,
+        checkpoint_trunk_pairformer = True,
+        checkpoint_diffusion_token_transformer = True,
         confidence_head_kwargs = dict(
             pairformer_depth = 1
         ),
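Enabling the new checkpoints is thus a constructor-level switch. A hedged sketch of construction, mirroring only the flags the updated test exercises; remaining constructor arguments are left at their defaults and may need adjusting for a real run:

from alphafold3_pytorch import Alphafold3

# sketch only; the surrounding test may pass further kwargs
alphafold3 = Alphafold3(
    dim_atom_inputs = 77,
    dim_template_feats = 44,
    num_dist_bins = 38,
    checkpoint_trunk_pairformer = True,
    checkpoint_diffusion_token_transformer = True,
    confidence_head_kwargs = dict(
        pairformer_depth = 1
    )
)

If should_checkpoint gates on training mode and gradient-requiring inputs, as the forward-pass guards suggest, these flags cost nothing at inference time.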
