Commit 8046385

first take care of checkpointing for pairformer stack
1 parent 051ee24 commit 8046385

File tree

4 files changed, +138 -25 lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 123 additions & 24 deletions
@@ -10,7 +10,7 @@
 from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
-from loguru import logger
+from torch.utils.checkpoint import checkpoint, checkpoint_sequential
 
 from torch.nn import (
     Module,
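
The swapped import above is the heart of the commit: torch.utils.checkpoint.checkpoint_sequential reruns each segment's forward pass during backpropagation instead of keeping every intermediate activation, trading compute for memory. A minimal standalone sketch of that PyTorch API (toy layers, not the repo's blocks):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

# eight stand-in layers; only segment boundaries are kept in memory,
# everything inside a segment is recomputed when backward() runs
layers = [nn.Linear(64, 64) for _ in range(8)]

x = torch.randn(2, 64, requires_grad = True)

out = checkpoint_sequential(layers, 2, x)   # split into 2 segments of 4 layers
out.sum().backward()
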
@@ -42,22 +42,20 @@
 from alphafold3_pytorch.inputs import (
     IS_MOLECULE_TYPES,
     IS_PROTEIN_INDEX,
+    IS_DNA_INDEX,
+    IS_RNA_INDEX,
     IS_LIGAND_INDEX,
     IS_METAL_ION_INDEX,
     IS_BIOMOLECULE_INDICES,
+    IS_PROTEIN,
+    IS_DNA,
+    IS_RNA,
+    IS_LIGAND,
+    IS_METAL_ION,
     NUM_MOLECULE_IDS,
     ADDITIONAL_MOLECULE_FEATS
 )
 
-
-IS_DNA_INDEX = 1
-IS_RNA_INDEX = 2
-
-IS_PROTEIN, IS_DNA, IS_RNA, IS_LIGAND, IS_METAL_ION = map(
-    lambda x: IS_MOLECULE_TYPES - x if x < 0 else x, [
-    IS_PROTEIN_INDEX, IS_DNA_INDEX, IS_RNA_INDEX, IS_LIGAND_INDEX, IS_METAL_ION_INDEX])
-
-
 from frame_averaging_pytorch import FrameAverage
 
 from taylor_series_linear_attention import TaylorSeriesLinearAttn
@@ -70,6 +68,8 @@
 
 from tqdm import tqdm
 
+from loguru import logger
+
 from importlib.metadata import version
 
 from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
@@ -169,6 +169,21 @@ def unpack_one(to_unpack, unpack_pattern = None):
 def exclusive_cumsum(t, dim = -1):
     return t.cumsum(dim = dim) - t
 
+# checkpointing utils
+
+@typecheck
+def should_checkpoint(
+    self: Module,
+    inputs: Tuple[Tensor, ...],
+    check_instance_variable: str | None = 'checkpoint'
+) -> bool:
+
+    return (
+        self.training and
+        any([i.requires_grad for i in inputs]) and
+        (not exists(check_instance_variable) or getattr(self, check_instance_variable, False))
+    )
+
 # decorators
 
 def maybe(fn):
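
The new should_checkpoint helper gates checkpointing on three conditions: the module is in training mode, at least one input tensor requires gradients, and the module opted in via its checkpoint attribute. A toy sketch of the same predicate (the Dummy module is hypothetical, not from the repo):

import torch
from torch.nn import Module

class Dummy(Module):
    def __init__(self, checkpoint = False):
        super().__init__()
        self.checkpoint = checkpoint

module = Dummy(checkpoint = True).train()
x = torch.randn(2, 8, requires_grad = True)

# mirrors should_checkpoint(module, (x,)) - True here; it is False in eval mode,
# when no input requires grad, or when the module was built with checkpoint = False
fires = module.training and x.requires_grad and getattr(module, 'checkpoint', False)
assert fires
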
@@ -350,8 +365,7 @@ def repeat_consecutive_with_lens(
 
     # final mask
 
-    if mask_value is None:
-        mask_value = False if dtype == torch.bool else 0
+    mask_value = default(mask_value, False if dtype == torch.bool else 0)
 
     output = einx.where(
         'b n, b n ..., -> b n ...',
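
The simplification above leans on the codebase's default helper instead of an explicit None check; roughly, it is the usual fallback pattern (a sketch, not necessarily the repo's exact definition):

def default(value, fallback):
    # use the fallback only when no value was provided
    return value if value is not None else fallback
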
@@ -1101,6 +1115,8 @@ def __init__(
         pair_bias_attn_heads = 16,
         dropout_row_prob = 0.25,
         num_register_tokens = 0,
+        checkpoint = False,
+        checkpoint_segments = 1,
         pairwise_block_kwargs: dict = dict(),
         pair_bias_attn_kwargs: dict = dict()
     ):
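
From the caller's side, the two new constructor arguments are the whole opt-in surface. A hedged usage sketch, mirroring the updated test further down (import path assumed from where the class is defined):

from alphafold3_pytorch.alphafold3 import PairformerStack

pairformer = PairformerStack(
    depth = 4,
    checkpoint = True,         # enable activation checkpointing while training
    checkpoint_segments = 1    # number of segments handed to checkpoint_sequential
)
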
@@ -1136,6 +1152,11 @@ def __init__(
 
         self.layers = layers
 
+        # checkpointing
+
+        self.checkpoint = checkpoint
+        self.checkpoint_segments = checkpoint_segments
+
         # https://arxiv.org/abs/2405.16039 and https://arxiv.org/abs/2405.15071
         # although possibly recycling already takes care of this
 
@@ -1150,6 +1171,80 @@ def __init__(
             self.pairwise_row_registers = nn.Parameter(torch.zeros(num_register_tokens, dim_pairwise))
             self.pairwise_col_registers = nn.Parameter(torch.zeros(num_register_tokens, dim_pairwise))
 
+    @typecheck
+    def to_layers(
+        self,
+        *,
+        single_repr: Float['b n ds'],
+        pairwise_repr: Float['b n n dp'],
+        mask: Bool['b n'] | None = None
+
+    ) -> Tuple[Float['b n ds'], Float['b n n dp']]:
+
+        for _ in range(self.recurrent_depth):
+            for (
+                pairwise_block,
+                pair_bias_attn,
+                single_transition
+            ) in self.layers:
+
+                pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
+
+                single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
+                single_repr = single_transition(single_repr) + single_repr
+
+        return single_repr, pairwise_repr
+
+    @typecheck
+    def to_checkpointed_layers(
+        self,
+        *,
+        single_repr: Float['b n ds'],
+        pairwise_repr: Float['b n n dp'],
+        mask: Bool['b n'] | None = None
+
+    ) -> Tuple[Float['b n ds'], Float['b n n dp']]:
+
+        inputs = (single_repr, pairwise_repr, mask)
+
+        def pairwise_block_wrapper(layer):
+            def inner(inputs, *args, **kwargs):
+                single_repr, pairwise_repr, mask = inputs
+                pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
+                return single_repr, pairwise_repr, mask
+            return inner
+
+        def pair_bias_attn_wrapper(layer):
+            def inner(inputs, *args, **kwargs):
+                single_repr, pairwise_repr, mask = inputs
+                single_repr = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
+                return single_repr, pairwise_repr, mask
+            return inner
+
+        def single_transition_wrapper(layer):
+            def inner(inputs, *args, **kwargs):
+                single_repr, pairwise_repr, mask = inputs
+                single_repr = layer(single_repr) + single_repr
+                return single_repr, pairwise_repr, mask
+            return inner
+
+        wrapped_layers = []
+
+        for _ in range(self.recurrent_depth):
+            for (
+                pairwise_block,
+                pair_bias_attn,
+                single_transition
+            ) in self.layers:
+
+                wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
+                wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
+                wrapped_layers.append(single_transition_wrapper(single_transition))
+
+        single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
+
+        return single_repr, pairwise_repr
+
     @typecheck
     def forward(
         self,
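
to_checkpointed_layers exists because checkpoint_sequential threads exactly one input object through a flat list of callables, whereas each pairformer sub-block needs the trio (single_repr, pairwise_repr, mask) plus its residual connection. Each sub-block is therefore wrapped in a closure that unpacks the carried tuple, applies its own update, and repacks it for the next wrapped layer. A stripped-down sketch of the same pattern with toy blocks (assumes a PyTorch recent enough for checkpoint_sequential to accept use_reentrant = False, which handles tuple inputs cleanly):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

blocks = nn.ModuleList([nn.Linear(32, 32) for _ in range(4)])

def wrap(block):
    def inner(inputs):
        x, mask = inputs        # unpack the carried state
        x = block(x) + x        # residual update, as in the pairformer blocks
        return x, mask          # repack for the next wrapped layer
    return inner

x = torch.randn(2, 32, requires_grad = True)
mask = torch.ones(2, 32, dtype = torch.bool)

x, mask = checkpoint_sequential(
    [wrap(b) for b in blocks], 2, (x, mask), use_reentrant = False
)
x.sum().backward()
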
@@ -1175,19 +1270,20 @@ def forward(
         if exists(mask):
             mask = F.pad(mask, (num_registers, 0), value = True)
 
-        # main transformer block layers
+        # maybe checkpoint
 
-        for _ in range(self.recurrent_depth):
-            for (
-                pairwise_block,
-                pair_bias_attn,
-                single_transition
-            ) in self.layers:
+        if should_checkpoint(self, (single_repr, pairwise_repr)):
+            to_layers_fn = self.to_checkpointed_layers
+        else:
+            to_layers_fn = self.to_layers
 
-                pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
+        # main transformer block layers
 
-                single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
-                single_repr = single_transition(single_repr) + single_repr
+        single_repr, pairwise_repr = to_layers_fn(
+            single_repr = single_repr,
+            pairwise_repr = pairwise_repr,
+            mask = mask
+        )
 
         # splice out registers
 
@@ -3644,6 +3740,9 @@ def compute_lddt(
         is_rna: boolean tensor indicating RNA atoms
         pairwise_mask: boolean tensor indicating atompair for which LDDT is computed
     """
+
+    atom_seq_len, device = pred_coords.shape[1], pred_coords.device
+
     # Compute distances between all pairs of atoms
     pred_dists = torch.cdist(pred_coords, pred_coords)
     true_dists = torch.cdist(true_coords, true_coords)
@@ -3669,7 +3768,7 @@
     )
 
     # Compute mean, avoiding self term
-    mask = inclusion_radius & ~torch.eye(pred_coords.shape[1], dtype=torch.bool, device=pred_coords.device)
+    mask = inclusion_radius & ~torch.eye(atom_seq_len, dtype=torch.bool, device=device)
 
     # Take into account variable lengthed atoms in batch
     if exists(coords_mask):
@@ -3700,7 +3799,7 @@ def compute_chain_pair_lddt(
         plddt between atoms maked by asym_mask_a and asym_mask_b
     """
 
-    if coords_mask is None:
+    if not exists(coords_mask):
         coords_mask = torch.ones_like(asym_mask_a)
 
     if asym_mask_a.ndim == 1:

alphafold3_pytorch/inputs.py

Lines changed: 7 additions & 0 deletions
@@ -64,10 +64,17 @@
 
 IS_MOLECULE_TYPES = 5
 IS_PROTEIN_INDEX = 0
+IS_DNA_INDEX = 1
+IS_RNA_INDEX = 2
 IS_LIGAND_INDEX = -2
 IS_METAL_ION_INDEX = -1
 IS_BIOMOLECULE_INDICES = slice(0, 3)
 
+IS_PROTEIN, IS_DNA, IS_RNA, IS_LIGAND, IS_METAL_ION = tuple(
+    (IS_MOLECULE_TYPES - i if i < 0 else i)
+    for i in [IS_PROTEIN_INDEX, IS_DNA_INDEX, IS_RNA_INDEX, IS_LIGAND_INDEX, IS_METAL_ION_INDEX]
+)
+
 MOLECULE_GAP_ID = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES)
 MOLECULE_METAL_ION_ID = MOLECULE_GAP_ID + 1
 NUM_MOLECULE_IDS = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) + 2
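
With IS_MOLECULE_TYPES = 5, the generator resolves negative indices by counting back from the end, so the exported constants evaluate to plain positions in the molecule-type one-hot:

IS_PROTEIN   == 0
IS_DNA       == 1
IS_RNA       == 2
IS_LIGAND    == 3   # 5 - 2
IS_METAL_ION == 4   # 5 - 1
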

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.62"
+version = "0.2.63"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 7 additions & 0 deletions
@@ -173,9 +173,11 @@ def test_centre_random_augmentation():
     assert augmented_coords.shape == coords.shape
 
 
+@pytest.mark.parametrize('checkpoint', (True, False))
 @pytest.mark.parametrize('recurrent_depth', (1, 2))
 @pytest.mark.parametrize('enable_attn_softclamp', (True, False))
 def test_pairformer(
+    checkpoint,
     recurrent_depth,
     enable_attn_softclamp
 ):
@@ -187,6 +189,7 @@ def test_pairformer(
         depth = 4,
         num_register_tokens = 4,
         recurrent_depth = recurrent_depth,
+        checkpoint = checkpoint,
         pair_bias_attn_kwargs = dict(
             enable_attn_softclamp = enable_attn_softclamp
         )
@@ -201,6 +204,10 @@ def test_pairformer(
     assert single.shape == single_out.shape
     assert pairwise.shape == pairwise_out.shape
 
+    if checkpoint:
+        loss = single_out.sum() + pairwise_out.sum()
+        loss.backward()
+
 def test_msa_module():
 
     single = torch.randn(2, 16, 384)
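
The new checkpoint parametrization also drives a backward pass, since checkpointing only changes behaviour once gradients flow; running pytest tests/test_af3.py -k test_pairformer exercises both the checkpointed and non-checkpointed paths.
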
