
Commit 73e4471

slightly more efficient linear attention

1 parent: 7fe88a5

File tree

3 files changed: +24, −13 lines

alphafold3_pytorch/alphafold3.py
alphafold3_pytorch/inputs.py
pyproject.toml

alphafold3_pytorch/alphafold3.py

Lines changed: 21 additions & 10 deletions
@@ -149,6 +149,11 @@

 LinearNoBias = partial(Linear, bias = False)

+# always use non reentrant checkpointing
+
+checkpoint = partial(checkpoint, use_reentrant = False)
+checkpoint_sequential = partial(checkpoint_sequential, use_reentrant = False)
+
 # helper functions

 def exists(v):
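
With the two `partial` rebinds added above, every later `checkpoint` / `checkpoint_sequential` call site in this diff can drop its explicit `use_reentrant = False` argument. A minimal sketch of the pattern using only the stock torch utilities (the layer below is a stand-in, not from the repo):

```python
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

# rebind the torch helpers so non-reentrant checkpointing is the default everywhere
checkpoint = partial(checkpoint, use_reentrant = False)
checkpoint_sequential = partial(checkpoint_sequential, use_reentrant = False)

layer = nn.Linear(16, 16)  # stand-in module
x = torch.randn(2, 16, requires_grad = True)

# call sites no longer pass use_reentrant themselves
out = checkpoint(layer, x)
out.sum().backward()
```
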
@@ -1179,7 +1184,7 @@ def inner(inputs):
 wrapped_layers.append(msa_transition_wrapper(msa_transition))
 wrapped_layers.append(pairwise_block_wrapper(pairwise_block))

-pairwise_repr, *_ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False)
+pairwise_repr, *_ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)

 return pairwise_repr
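
The same simplification repeats below for every stack that calls `checkpoint_sequential`. For reference, a hedged, self-contained sketch of how `checkpoint_sequential` consumes a list of callables; the layers and segment count are invented for illustration and are not the repo's wrappers:

```python
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

checkpoint_sequential = partial(checkpoint_sequential, use_reentrant = False)

# stand-ins for the wrapped blocks: each callable maps a tensor to a tensor
layers = [nn.Sequential(nn.Linear(32, 32), nn.GELU()) for _ in range(4)]

x = torch.randn(2, 32, requires_grad = True)

# split the 4 layers into 2 checkpointed segments; activations inside each
# segment are recomputed during backward instead of being kept in memory
out = checkpoint_sequential(layers, 2, x)
out.sum().backward()
```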

@@ -1399,7 +1404,7 @@ def inner(inputs, *args, **kwargs):
 wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
 wrapped_layers.append(single_transition_wrapper(single_transition))

-single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False)
+single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)

 return single_repr, pairwise_repr

@@ -1618,7 +1623,7 @@ def inner(inputs):
 for block in self.pairformer_stack:
     wrapped_layers.append(block_wrapper(block))

-templates, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False)
+templates, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)

 return templates

@@ -1858,6 +1863,7 @@ def __init__(
 dim = dim,
 prenorm = True,
 gate_value_heads = True,
+remove_even_power_dups = True,
 **linear_attn_kwargs
 )
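
The new `remove_even_power_dups = True` is the change behind the commit message; it is forwarded to the Taylor-series linear attention module and needs `taylor-series-linear-attention>=0.1.12` (see the pyproject.toml bump below). The likely saving, hedged since the library internals are not shown here, is that the second-order term of the Taylor-expanded attention kernel is a symmetric outer product, so its duplicated cross terms can be dropped. An illustrative sketch of that redundancy, not the library's implementation:

```python
import torch

def second_order_taylor_features(x):
    # second-order Taylor features for exp(q . k): [1, x, flatten(x ⊗ x)]
    # (normalization constants omitted for brevity)
    outer = torch.einsum('... i, ... j -> ... i j', x, x)
    return torch.cat((torch.ones_like(x[..., :1]), x, outer.flatten(-2)), dim = -1)

d = 8
x = torch.randn(d)

# the outer product is symmetric (x_i * x_j == x_j * x_i), so keeping only the
# upper triangle stores d * (d + 1) / 2 second-order terms instead of d * d
triu = torch.triu_indices(d, d)
deduped = torch.outer(x, x)[triu[0], triu[1]]

print(second_order_taylor_features(x).numel())  # 1 + 8 + 64 = 73
print(deduped.numel())                          # 36 second-order terms after dedup
```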

@@ -1971,7 +1977,7 @@ def inner(inputs):
 wrapped_layers.append(attn_wrapper(attn))
 wrapped_layers.append(transition_wrapper(transition))

-out = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False)
+out = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)

 noised_repr, *_ = out
 return noised_repr

@@ -2439,7 +2445,7 @@ def forward(
 token_transformer = self.token_transformer

 if should_checkpoint(self, tokens, 'checkpoint_token_transformer'):
-    token_transformer = partial(checkpoint, token_transformer, use_reentrant = False)
+    token_transformer = partial(checkpoint, token_transformer)

 # token transformer

@@ -2917,9 +2923,7 @@ def forward(
 mask = mask & paired_coords_mask

 # Calculate masked averaging
-lddt_sum = (eps * mask).sum(dim=(-1, -2))
-lddt_count = mask.sum(dim=(-1, -2))
-lddt = lddt_sum / lddt_count.clamp(min=1)
+lddt = masked_average(eps, mask = mask, dim = (-1, -2), eps = 1)

 return 1. - lddt.mean()
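
The three removed lines are folded into a `masked_average` helper elsewhere in the repo. A sketch consistent with the deleted code, where `eps` guards the denominator the same way the old `clamp(min=1)` did; the real helper may differ in details:

```python
import torch

def masked_average(t, mask, dim, eps = 1.):
    # sum the masked values and divide by the (clamped) count of valid entries
    num = (t * mask).sum(dim = dim)
    den = mask.float().sum(dim = dim).clamp(min = eps)
    return num / den

lddt_scores = torch.rand(2, 16, 16)
mask = torch.rand(2, 16, 16) > 0.5

lddt = masked_average(lddt_scores, mask = mask, dim = (-1, -2), eps = 1)
loss = 1. - lddt.mean()
```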

@@ -4994,6 +4998,7 @@ def __init__(
 distogram_atom_resolution = False,
 checkpoint_input_embedding = False,
 checkpoint_trunk_pairformer = False,
+checkpoint_distogram_head = True,
 checkpoint_diffusion_token_transformer = False,
 detach_when_recycling = True,
 pdb_training_set=True,

@@ -5213,6 +5218,7 @@ def __init__(

 self.checkpoint_trunk_pairformer = checkpoint_trunk_pairformer
 self.checkpoint_diffusion_token_transformer = checkpoint_diffusion_token_transformer
+self.checkpoint_distogram_head = checkpoint_distogram_head

 # loss related

@@ -5566,7 +5572,7 @@ def forward(
 pairformer = self.pairformer

 if should_checkpoint(self, (single, pairwise), 'checkpoint_trunk_pairformer'):
-    pairformer = partial(checkpoint, pairformer, use_reentrant = False)
+    pairformer = partial(checkpoint, pairformer)

 # main attention trunk (pairformer)

@@ -5693,7 +5699,12 @@ def forward(

 distance_labels = torch.where(distogram_mask, distance_labels, ignore)

-distogram_logits = self.distogram_head(
+distogram_head_fn = self.distogram_head
+
+if should_checkpoint(self, pairwise, 'checkpoint_distogram_head'):
+    distogram_head_fn = partial(checkpoint, distogram_head_fn)
+
+distogram_logits = distogram_head_fn(
 pairwise,
 molecule_atom_lens = molecule_atom_lens,
 atom_feats = atom_feats
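
The distogram head now goes through the same `should_checkpoint` gate already used for the trunk pairformer and the diffusion token transformer: when the gate passes, the module is wrapped with `partial(checkpoint, ...)` so its activations are recomputed during backward. A rough sketch of the gating pattern, with a hypothetical stand-in for `should_checkpoint` that checks training mode, the named flag, and whether any input carries gradients (the repo's actual helper may check more):

```python
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

checkpoint = partial(checkpoint, use_reentrant = False)

def should_checkpoint(module, inputs, flag_name):
    # hypothetical stand-in: checkpoint only during training, when the flag is
    # set on the module, and when at least one input will need gradients
    inputs = (inputs,) if torch.is_tensor(inputs) else inputs
    return (
        module.training and
        getattr(module, flag_name, False) and
        any(t.requires_grad for t in inputs)
    )

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.checkpoint_distogram_head = True
        self.distogram_head = nn.Linear(64, 64)  # stand-in for the real head

    def forward(self, pairwise):
        head_fn = self.distogram_head

        if should_checkpoint(self, pairwise, 'checkpoint_distogram_head'):
            head_fn = partial(checkpoint, head_fn)

        return head_fn(pairwise)

model = Model().train()
logits = model(torch.randn(2, 16, 16, 64, requires_grad = True))
```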

alphafold3_pytorch/inputs.py

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ def file_to_atom_input(path: str | Path) -> AtomInput:

 assert path.is_file()

-atom_input_dict = torch.load(str(path))
+atom_input_dict = torch.load(str(path), weights_only = True)
 return AtomInput(**atom_input_dict)

 @typecheck
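
`weights_only = True` makes `torch.load` use a restricted unpickler, so a serialized atom-input dict can contain tensors and plain containers but not arbitrary pickled Python objects. A small round-trip sketch (the field name is illustrative, not the full `AtomInput` schema):

```python
import torch

# save a plain dict of tensors, the shape of data file_to_atom_input expects
atom_input_dict = dict(atom_inputs = torch.randn(16, 3))
torch.save(atom_input_dict, 'atom_input.pt')

# with weights_only = True, tensors, dicts, lists and other safe primitives
# load fine, while arbitrary Python objects are rejected
loaded = torch.load('atom_input.pt', weights_only = True)
assert torch.equal(loaded['atom_inputs'], atom_input_dict['atom_inputs'])
```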

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.3.10"
+version = "0.3.11"
 description = "Alphafold 3 - Pytorch"
 authors = [
 { name = "Phil Wang", email = "[email protected]" },

@@ -49,7 +49,7 @@ dependencies = [
 "sh>=2.0.7",
 "shortuuid",
 "tensorboard",
-"taylor-series-linear-attention>=0.1.11",
+"taylor-series-linear-attention>=0.1.12",
 "torchtyping>=0.1.5",
 "timeout_decorator>=0.5.0",
 'torch_geometric',
