@@ -10,7 +10,7 @@
 from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
-from loguru import logger
+from torch.utils.checkpoint import checkpoint, checkpoint_sequential
 
 from torch.nn import (
     Module,
@@ -42,22 +42,20 @@
 from alphafold3_pytorch.inputs import (
     IS_MOLECULE_TYPES,
     IS_PROTEIN_INDEX,
+    IS_DNA_INDEX,
+    IS_RNA_INDEX,
     IS_LIGAND_INDEX,
     IS_METAL_ION_INDEX,
     IS_BIOMOLECULE_INDICES,
+    IS_PROTEIN,
+    IS_DNA,
+    IS_RNA,
+    IS_LIGAND,
+    IS_METAL_ION,
     NUM_MOLECULE_IDS,
     ADDITIONAL_MOLECULE_FEATS
 )
 
-
-IS_DNA_INDEX = 1
-IS_RNA_INDEX = 2
-
-IS_PROTEIN, IS_DNA, IS_RNA, IS_LIGAND, IS_METAL_ION = map(
-    lambda x: IS_MOLECULE_TYPES - x if x < 0 else x, [
-    IS_PROTEIN_INDEX, IS_DNA_INDEX, IS_RNA_INDEX, IS_LIGAND_INDEX, IS_METAL_ION_INDEX])
-
-
 from frame_averaging_pytorch import FrameAverage
 
 from taylor_series_linear_attention import TaylorSeriesLinearAttn
@@ -70,6 +68,8 @@
 
 from tqdm import tqdm
 
+from loguru import logger
+
 from importlib.metadata import version
 
 from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
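
The new torch.utils.checkpoint import above drives the activation checkpointing added further down. As a reminder of the basic semantics, a minimal self-contained sketch (module and shapes are made up for illustration): a checkpointed call frees its intermediate activations after the forward pass and recomputes them during backward, trading extra compute for lower peak memory.

    # minimal sketch of torch.utils.checkpoint.checkpoint -- the wrapped module's
    # intermediate activations are recomputed on backward instead of being stored
    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    net = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
    x = torch.randn(2, 16, requires_grad = True)

    out = checkpoint(net, x, use_reentrant = False)
    out.sum().backward()
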
@@ -169,6 +169,21 @@ def unpack_one(to_unpack, unpack_pattern = None):
 def exclusive_cumsum(t, dim = -1):
     return t.cumsum(dim = dim) - t
 
+# checkpointing utils
+
+@typecheck
+def should_checkpoint(
+    self: Module,
+    inputs: Tuple[Tensor, ...],
+    check_instance_variable: str | None = 'checkpoint'
+) -> bool:
+
+    return (
+        self.training and
+        any([i.requires_grad for i in inputs]) and
+        (not exists(check_instance_variable) or getattr(self, check_instance_variable, False))
+    )
+
 # decorators
 
 def maybe(fn):
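
The should_checkpoint helper above gates checkpointing on three conditions: the module is in training mode, at least one input requires grad, and the module has opted in via its checkpoint attribute. A hedged, self-contained sketch of the same predicate on a toy module (ToyBlock is illustrative, not part of the repo):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    class ToyBlock(nn.Module):
        def __init__(self, dim = 16, checkpoint = False):
            super().__init__()
            self.checkpoint = checkpoint  # the flag should_checkpoint would read via getattr
            self.net = nn.Linear(dim, dim)

        def forward(self, x):
            # same three conditions as should_checkpoint, written out inline
            use_ckpt = self.training and x.requires_grad and getattr(self, 'checkpoint', False)

            if use_ckpt:
                return checkpoint(self.net, x, use_reentrant = False)

            return self.net(x)

    block = ToyBlock(checkpoint = True).train()
    out = block(torch.randn(2, 16, requires_grad = True))
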
@@ -350,8 +365,7 @@ def repeat_consecutive_with_lens(
 
     # final mask
 
-    if mask_value is None:
-        mask_value = False if dtype == torch.bool else 0
+    mask_value = default(mask_value, False if dtype == torch.bool else 0)
 
     output = einx.where(
         'b n, b n ..., -> b n ...',
@@ -1101,6 +1115,8 @@ def __init__(
         pair_bias_attn_heads = 16,
         dropout_row_prob = 0.25,
         num_register_tokens = 0,
+        checkpoint = False,
+        checkpoint_segments = 1,
         pairwise_block_kwargs: dict = dict(),
         pair_bias_attn_kwargs: dict = dict()
     ):
@@ -1136,6 +1152,11 @@ def __init__(
 
         self.layers = layers
 
+        # checkpointing
+
+        self.checkpoint = checkpoint
+        self.checkpoint_segments = checkpoint_segments
+
         # https://arxiv.org/abs/2405.16039 and https://arxiv.org/abs/2405.15071
         # although possibly recycling already takes care of this
 
@@ -1150,6 +1171,80 @@ def __init__(
         self.pairwise_row_registers = nn.Parameter(torch.zeros(num_register_tokens, dim_pairwise))
         self.pairwise_col_registers = nn.Parameter(torch.zeros(num_register_tokens, dim_pairwise))
 
+    @typecheck
+    def to_layers(
+        self,
+        *,
+        single_repr: Float['b n ds'],
+        pairwise_repr: Float['b n n dp'],
+        mask: Bool['b n'] | None = None
+
+    ) -> Tuple[Float['b n ds'], Float['b n n dp']]:
+
+        for _ in range(self.recurrent_depth):
+            for (
+                pairwise_block,
+                pair_bias_attn,
+                single_transition
+            ) in self.layers:
+
+                pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
+
+                single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
+                single_repr = single_transition(single_repr) + single_repr
+
+        return single_repr, pairwise_repr
+
+    @typecheck
+    def to_checkpointed_layers(
+        self,
+        *,
+        single_repr: Float['b n ds'],
+        pairwise_repr: Float['b n n dp'],
+        mask: Bool['b n'] | None = None
+
+    ) -> Tuple[Float['b n ds'], Float['b n n dp']]:
+
+        inputs = (single_repr, pairwise_repr, mask)
+
+        def pairwise_block_wrapper(layer):
+            def inner(inputs, *args, **kwargs):
+                single_repr, pairwise_repr, mask = inputs
+                pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
+                return single_repr, pairwise_repr, mask
+            return inner
+
+        def pair_bias_attn_wrapper(layer):
+            def inner(inputs, *args, **kwargs):
+                single_repr, pairwise_repr, mask = inputs
+                single_repr = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
+                return single_repr, pairwise_repr, mask
+            return inner
+
+        def single_transition_wrapper(layer):
+            def inner(inputs, *args, **kwargs):
+                single_repr, pairwise_repr, mask = inputs
+                single_repr = layer(single_repr) + single_repr
+                return single_repr, pairwise_repr, mask
+            return inner
+
+        wrapped_layers = []
+
+        for _ in range(self.recurrent_depth):
+            for (
+                pairwise_block,
+                pair_bias_attn,
+                single_transition
+            ) in self.layers:
+
+                wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
+                wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
+                wrapped_layers.append(single_transition_wrapper(single_transition))
+
+        single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
+
+        return single_repr, pairwise_repr
+
     @typecheck
     def forward(
         self,
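
checkpoint_sequential expects a list of callables where each one takes the previous output as its single input, so the wrappers above close over one layer each and thread the full (single_repr, pairwise_repr, mask) tuple through, residually updating only the element that layer is responsible for. A standalone sketch of that pattern under made-up layer names (note the non-reentrant variant, which handles tensors packed inside a tuple):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint_sequential

    dim = 8
    layer_a, layer_b = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def wrap(layer, index):
        # each wrapped callable consumes and returns the whole tuple,
        # updating only the element it owns with a residual connection
        def inner(inputs):
            inputs = list(inputs)
            inputs[index] = layer(inputs[index]) + inputs[index]
            return tuple(inputs)
        return inner

    wrapped = [wrap(layer_a, 0), wrap(layer_b, 1)]

    single = torch.randn(2, 4, dim, requires_grad = True)
    pairwise = torch.randn(2, 4, 4, dim, requires_grad = True)

    single, pairwise = checkpoint_sequential(wrapped, 2, (single, pairwise), use_reentrant = False)
    (single.sum() + pairwise.sum()).backward()
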
@@ -1175,19 +1270,20 @@ def forward(
         if exists(mask):
             mask = F.pad(mask, (num_registers, 0), value = True)
 
-        # main transformer block layers
+        # maybe checkpoint
 
-        for _ in range(self.recurrent_depth):
-            for (
-                pairwise_block,
-                pair_bias_attn,
-                single_transition
-            ) in self.layers:
+        if should_checkpoint(self, (single_repr, pairwise_repr)):
+            to_layers_fn = self.to_checkpointed_layers
+        else:
+            to_layers_fn = self.to_layers
 
-                pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
+        # main transformer block layers
 
-                single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
-                single_repr = single_transition(single_repr) + single_repr
+        single_repr, pairwise_repr = to_layers_fn(
+            single_repr = single_repr,
+            pairwise_repr = pairwise_repr,
+            mask = mask
+        )
 
         # splice out registers
 
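
With the dispatch above, enabling checkpointing is just a constructor flag. A hedged usage sketch, assuming the enclosing class here is PairformerStack from this repo with its usual defaults; treat the import path and the representation dimensions below as assumptions:

    import torch
    from alphafold3_pytorch.alphafold3 import PairformerStack  # assumed import path

    stack = PairformerStack(
        depth = 2,
        checkpoint = True,        # opt in to activation checkpointing
        checkpoint_segments = 1   # segments handed to checkpoint_sequential
    ).train()                     # should_checkpoint only fires in training mode

    single = torch.randn(1, 16, 384, requires_grad = True)        # assumes dim_single = 384
    pairwise = torch.randn(1, 16, 16, 128, requires_grad = True)  # assumes dim_pairwise = 128
    mask = torch.ones(1, 16).bool()

    single, pairwise = stack(
        single_repr = single,
        pairwise_repr = pairwise,
        mask = mask
    )
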
@@ -3644,6 +3740,9 @@ def compute_lddt(
         is_rna: boolean tensor indicating RNA atoms
         pairwise_mask: boolean tensor indicating atompair for which LDDT is computed
     """
+
+    atom_seq_len, device = pred_coords.shape[1], pred_coords.device
+
     # Compute distances between all pairs of atoms
     pred_dists = torch.cdist(pred_coords, pred_coords)
     true_dists = torch.cdist(true_coords, true_coords)
@@ -3669,7 +3768,7 @@ def compute_lddt(
     )
 
     # Compute mean, avoiding self term
-    mask = inclusion_radius & ~torch.eye(pred_coords.shape[1], dtype = torch.bool, device = pred_coords.device)
+    mask = inclusion_radius & ~torch.eye(atom_seq_len, dtype = torch.bool, device = device)
 
     # Take into account variable lengthed atoms in batch
     if exists(coords_mask):
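
The only change here hoists atom_seq_len and device out for reuse; the masking itself is unchanged. For clarity, a standalone illustration of excluding self-pairs with ~torch.eye (shapes and cutoff are arbitrary for this sketch):

    import torch

    atom_seq_len, device = 5, torch.device('cpu')

    coords = torch.randn(1, atom_seq_len, 3, device = device)
    dists = torch.cdist(coords, coords)

    inclusion_radius = dists < 15.0  # arbitrary cutoff for this sketch

    # knock out the diagonal so an atom's zero distance to itself never contributes
    mask = inclusion_radius & ~torch.eye(atom_seq_len, dtype = torch.bool, device = device)
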
@@ -3700,7 +3799,7 @@ def compute_chain_pair_lddt(
         plddt between atoms maked by asym_mask_a and asym_mask_b
     """
 
-    if coords_mask is None:
+    if not exists(coords_mask):
         coords_mask = torch.ones_like(asym_mask_a)
 
     if asym_mask_a.ndim == 1: