@@ -13,7 +13,6 @@
from torch import Tensor
from torch.amp import autocast
import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint, checkpoint_sequential

from torch.nn import (
    Module,
@@ -73,6 +72,7 @@
    ExpressCoordinatesInFrame,
    RigidFrom3Points,
    calculate_weighted_rigid_align_weights,
+    package_available,
)

from frame_averaging_pytorch import FrameAverage
@@ -84,6 +84,7 @@
import einx
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange
+from environs import Env

from tqdm import tqdm

@@ -169,10 +170,23 @@

LinearNoBias = partial(Linear, bias = False)

+# environment
+
+env = Env()
+env.read_env()
+
# always use non reentrant checkpointing

-checkpoint = partial(checkpoint, use_reentrant = False)
-checkpoint_sequential = partial(checkpoint_sequential, use_reentrant = False)
+DEEPSPEED_CHECKPOINTING = env.bool('DEEPSPEED_CHECKPOINTING', False)
+
+if DEEPSPEED_CHECKPOINTING:
+    assert package_available("deepspeed"), "DeepSpeed must be installed for checkpointing."
+
+    import deepspeed
+
+    checkpoint = deepspeed.checkpointing.checkpoint
+else:
+    checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant = False)

# helper functions

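For readers skimming the hunk above: both branches expose the same `checkpoint(fn, *args)` call surface, so downstream code stays backend-agnostic. A minimal sketch of how the selected function behaves (toy layer and dimensions are assumptions, not from this commit):

```python
# sketch only: non-reentrant torch checkpointing, as in the else-branch above
from functools import partial

import torch
import torch.utils.checkpoint
from torch import nn

checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant = False)

layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad = True)

# activations of `layer` are recomputed during backward rather than stored
out = checkpoint(layer, x)
out.sum().backward()
```

Setting `DEEPSPEED_CHECKPOINTING=1` in the environment (or in a `.env` file read by `environs`) swaps in `deepspeed.checkpointing.checkpoint`, which is called the same way.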
@@ -1061,7 +1075,6 @@ def __init__(
        msa_pwa_heads = 8,
        msa_pwa_dim_head = 32,
        checkpoint = False,
-        checkpoint_segments = 1,
        pairwise_block_kwargs: dict = dict(),
        max_num_msa: int | None = None,
        layerscale_output: bool = True
@@ -1112,7 +1125,6 @@ def __init__(
        ]))

        self.checkpoint = checkpoint
-        self.checkpoint_segments = checkpoint_segments

        self.layers = layers

@@ -1182,19 +1194,19 @@ def inner(inputs):
                return pairwise_repr, mask, msa, msa_mask
            return inner

-        def pairwise_block_wrapper(fn):
+        def msa_transition_wrapper(fn):
            @wraps(fn)
            def inner(inputs):
                pairwise_repr, mask, msa, msa_mask = inputs
-                pairwise_repr = fn(pairwise_repr = pairwise_repr, mask = mask)
+                msa = fn(msa) + msa
                return pairwise_repr, mask, msa, msa_mask
            return inner

-        def msa_transition_wrapper(fn):
+        def pairwise_block_wrapper(fn):
            @wraps(fn)
            def inner(inputs):
                pairwise_repr, mask, msa, msa_mask = inputs
-                msa = fn(msa) + msa
+                pairwise_repr = fn(pairwise_repr = pairwise_repr, mask = mask)
                return pairwise_repr, mask, msa, msa_mask
            return inner

@@ -1210,8 +1222,10 @@ def inner(inputs):
            wrapped_layers.append(msa_transition_wrapper(msa_transition))
            wrapped_layers.append(pairwise_block_wrapper(pairwise_block))

-        pairwise_repr, *_ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
+        for layer in wrapped_layers:
+            inputs = checkpoint(layer, inputs)

+        pairwise_repr, *_ = inputs
        return pairwise_repr

    @typecheck
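Each wrapped layer consumes and returns the same `(pairwise_repr, mask, msa, msa_mask)` tuple, so replacing `checkpoint_sequential` with a plain loop just threads the tuple through one checkpointed call per layer. A toy reduction of that pattern (two-field state and linear layers are hypothetical stand-ins):

```python
# sketch of the per-layer checkpoint loop used in the hunk above
from functools import partial

import torch
import torch.utils.checkpoint
from torch import nn

checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant = False)

def layer_wrapper(fn):
    # every wrapper takes and returns the full state tuple
    def inner(inputs):
        x, mask = inputs
        return fn(x) + x, mask  # residual update; mask passes through untouched
    return inner

layers = [nn.Linear(8, 8) for _ in range(3)]
inputs = (torch.randn(2, 8, requires_grad = True), torch.ones(2, 8))

# one checkpointed call per layer, threading the state tuple through
for layer in map(layer_wrapper, layers):
    inputs = checkpoint(layer, inputs)

out, _ = inputs
out.sum().backward()
```

A plausible motivation, not stated in the commit: `checkpoint_sequential` has no DeepSpeed counterpart, while a single-call `checkpoint` exists in both backends, which would also explain why `checkpoint_segments` is removed throughout.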
@@ -1318,7 +1332,6 @@ def __init__(
        dropout_row_prob = 0.25,
        num_register_tokens = 0,
        checkpoint = False,
-        checkpoint_segments = 1,
        pairwise_block_kwargs: dict = dict(),
        pair_bias_attn_kwargs: dict = dict()
    ):
@@ -1357,7 +1370,6 @@ def __init__(
        # checkpointing

        self.checkpoint = checkpoint
-        self.checkpoint_segments = checkpoint_segments

        # https://arxiv.org/abs/2405.16039 and https://arxiv.org/abs/2405.15071
        # although possibly recycling already takes care of this
@@ -1446,8 +1458,10 @@ def inner(inputs, *args, **kwargs):
            wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
            wrapped_layers.append(single_transition_wrapper(single_transition))

-        single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
+        for layer in wrapped_layers:
+            inputs = checkpoint(layer, inputs)

+        single_repr, pairwise_repr, _ = inputs
        return single_repr, pairwise_repr

    @typecheck
@@ -1590,7 +1604,6 @@ def __init__(
        pairwise_block_kwargs: dict = dict(),
        eps = 1e-5,
        checkpoint = False,
-        checkpoint_segments = 1,
        layerscale_output = True
    ):
        super().__init__()
@@ -1615,7 +1628,6 @@ def __init__(
        self.pairformer_stack = layers

        self.checkpoint = checkpoint
-        self.checkpoint_segments = checkpoint_segments

        self.final_norm = nn.LayerNorm(dim)

@@ -1666,8 +1678,10 @@ def inner(inputs):
        for block in self.pairformer_stack:
            wrapped_layers.append(block_wrapper(block))

-        templates, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
+        for layer in wrapped_layers:
+            inputs = checkpoint(layer, inputs)

+        templates, _ = inputs
        return templates

    @typecheck
@@ -1877,7 +1891,6 @@ def __init__(
        add_residual = True,
        use_linear_attn = False,
        checkpoint = False,
-        checkpoint_segments = 1,
        linear_attn_kwargs = dict(
            heads = 8,
            dim_head = 16
@@ -1956,7 +1969,6 @@ def __init__(
        assert not (not serial and checkpoint), 'checkpointing can only be used for serial version of diffusion transformer'

        self.checkpoint = checkpoint
-        self.checkpoint_segments = checkpoint_segments

        self.layers = layers

@@ -2021,9 +2033,10 @@ def inner(inputs):
            wrapped_layers.append(attn_wrapper(attn))
            wrapped_layers.append(transition_wrapper(transition))

-        out = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
+        for layer in wrapped_layers:
+            inputs = checkpoint(layer, inputs)

-        noised_repr, *_ = out
+        noised_repr, *_ = inputs
        return noised_repr

    @typecheck
@@ -2314,10 +2327,6 @@ def __init__(

        self.attended_token_norm = nn.LayerNorm(dim_token)

-        # checkpointing
-
-        self.checkpoint_token_transformer = checkpoint_token_transformer
-
        # atom attention decoding related modules

        self.tokens_to_atom_decoder_input_cond = LinearNoBias(dim_token, dim_atom)
@@ -2332,6 +2341,7 @@ def __init__(
            serial = serial,
            use_linear_attn = use_linear_attn,
            linear_attn_kwargs = linear_attn_kwargs,
+            checkpoint = checkpoint_token_transformer,
            **atom_decoder_kwargs
        )

@@ -2484,18 +2494,11 @@ def forward(
            molecule_atom_lens = molecule_atom_lens
        )

-        # maybe checkpoint token transformer
-
-        token_transformer = self.token_transformer
-
-        if should_checkpoint(self, tokens, 'checkpoint_token_transformer'):
-            token_transformer = partial(checkpoint, token_transformer)
-
        # token transformer

        tokens = self.cond_tokens_with_cond_single(conditioned_single_repr) + tokens

-        tokens = token_transformer(
+        tokens = self.token_transformer(
            tokens,
            mask = mask,
            single_repr = conditioned_single_repr,
@@ -5991,6 +5994,7 @@ def __init__(
            dim_template_feats = dim_template_feats,
            dim = dim_template_model,
            dim_pairwise = dim_pairwise,
+            checkpoint = checkpoint_input_embedding,
            **template_embedder_kwargs
        )

@@ -6003,6 +6007,7 @@ def __init__(
            dim_pairwise = dim_pairwise,
            dim_msa_input = dim_msa_inputs,
            dim_additional_msa_feats = dim_additional_msa_feats,
+            checkpoint = checkpoint_input_embedding,
            **msa_module_kwargs,
        )

@@ -6011,6 +6016,7 @@ def __init__(
        self.pairformer = PairformerStack(
            dim_single = dim_single,
            dim_pairwise = dim_pairwise,
+            checkpoint = checkpoint_trunk_pairformer,
            **pairformer_stack
        )

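Note the recurring pattern in these constructor hunks: instead of wrapping call sites with `should_checkpoint` plus `partial(checkpoint, ...)` (removed further down), each stack now receives a `checkpoint` flag at construction and decides internally when to checkpoint. A hedged illustration of the resulting call site (dimension values are placeholders, not from this commit):

```python
# hypothetical construction; dims are placeholders
from alphafold3_pytorch import PairformerStack

pairformer = PairformerStack(
    dim_single = 384,
    dim_pairwise = 128,
    checkpoint = True  # blocks checkpoint themselves inside forward, typically only while training
)
```

This keeps `forward` free of checkpointing branches, as the simplified trunk and distogram-head call sites below show.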
@@ -6115,13 +6121,6 @@ def __init__(

        self.register_buffer('lddt_thresholds', torch.tensor([0.5, 1.0, 2.0, 4.0]))

-        # checkpointing related
-
-        self.checkpoint_trunk_pairformer = checkpoint_trunk_pairformer
-        self.checkpoint_diffusion_token_transformer = checkpoint_diffusion_token_transformer
-        self.checkpoint_distogram_head = checkpoint_distogram_head
-        self.checkpoint_confidence_head = checkpoint_confidence_head
-
        # loss related

        self.ignore_index = ignore_index
@@ -6510,16 +6509,9 @@ def forward(

        pairwise = embedded_msa + pairwise

-        # maybe checkpoint trunk pairformer
-
-        pairformer = self.pairformer
-
-        if should_checkpoint(self, (single, pairwise), 'checkpoint_trunk_pairformer'):
-            pairformer = partial(checkpoint, pairformer)
-
        # main attention trunk (pairformer)

-        single, pairwise = pairformer(
+        single, pairwise = self.pairformer(
            single_repr = single,
            pairwise_repr = pairwise,
            mask = mask
@@ -6650,12 +6642,7 @@ def forward(

        distance_labels = torch.where(distogram_mask, distance_labels, ignore)

-        distogram_head_fn = self.distogram_head
-
-        if should_checkpoint(self, pairwise, 'checkpoint_distogram_head'):
-            distogram_head_fn = partial(checkpoint, distogram_head_fn)
-
-        distogram_logits = distogram_head_fn(
+        distogram_logits = self.distogram_head(
            pairwise,
            molecule_atom_lens = molecule_atom_lens,
            atom_feats = atom_feats