@@ -2130,6 +2130,7 @@ def __init__(
        atom_decoder_kwargs: dict = dict(),
        token_transformer_kwargs: dict = dict(),
        use_linear_attn = False,
+       checkpoint_token_transformer = False,
        linear_attn_kwargs: dict = dict(
            heads = 8,
            dim_head = 16
@@ -2222,6 +2223,10 @@ def __init__(

        self.attended_token_norm = nn.LayerNorm(dim_token)

+       # checkpointing
+
+       self.checkpoint_token_transformer = checkpoint_token_transformer
+
        # atom attention decoding related modules

        self.tokens_to_atom_decoder_input_cond = LinearNoBias(dim_token, dim_atom)
@@ -2378,11 +2383,18 @@ def forward(
            molecule_atom_lens = molecule_atom_lens
        )

+       # maybe checkpoint token transformer
+
+       token_transformer = self.token_transformer
+
+       if should_checkpoint(self, tokens, 'checkpoint_token_transformer'):
+           token_transformer = partial(checkpoint, token_transformer, use_reentrant = False)
+
        # token transformer

        tokens = self.cond_tokens_with_cond_single(conditioned_single_repr) + tokens

-       tokens = self.token_transformer(
+       tokens = token_transformer(
            tokens,
            mask = mask,
            single_repr = conditioned_single_repr,
@@ -4300,7 +4312,10 @@ def __init__(
        ),
        augment_kwargs: dict = dict(),
        stochastic_frame_average = False,
-       confidence_head_atom_resolution = False
+       confidence_head_atom_resolution = False,
+       checkpoint_input_embedding = False,
+       checkpoint_trunk_pairformer = False,
+       checkpoint_diffusion_token_transformer = False,
    ):
        super().__init__()
@@ -4447,6 +4462,7 @@ def __init__(
            dim_atompair = dim_atompair,
            dim_token = dim_token,
            dim_single = dim_single + dim_single_inputs,
+           checkpoint_token_transformer = checkpoint_diffusion_token_transformer,
            **diffusion_module_kwargs
        )
44524468
@@ -4484,6 +4500,11 @@ def __init__(
44844500 ** confidence_head_kwargs
44854501 )
44864502
4503+ # checkpointing related
4504+
4505+ self .checkpoint_trunk_pairformer = checkpoint_trunk_pairformer
4506+ self .checkpoint_diffusion_token_transformer = checkpoint_diffusion_token_transformer
4507+
44874508 # loss related
44884509
44894510 self .ignore_index = ignore_index
@@ -4817,9 +4838,16 @@ def forward(

        pairwise = embedded_msa + pairwise

+       # maybe checkpoint trunk pairformer
+
+       pairformer = self.pairformer
+
+       if should_checkpoint(self, (single, pairwise), 'checkpoint_trunk_pairformer'):
+           pairformer = partial(checkpoint, pairformer, use_reentrant = False)
+
        # main attention trunk (pairformer)

-       single, pairwise = self.pairformer(
+       single, pairwise = pairformer(
            single_repr = single,
            pairwise_repr = pairwise,
            mask = mask
@@ -4877,7 +4905,7 @@ def forward(
            pred_atom_pos = confidence_head_atom_pos_input.detach(),
            molecule_atom_indices = molecule_atom_indices,
            molecule_atom_lens = molecule_atom_lens,
-           atom_feats = atom_feats,
+           atom_feats = atom_feats.detach(),
            mask = mask,
            return_pae_logits = True
        )
@@ -5056,7 +5084,7 @@ def forward(
            molecule_atom_indices = molecule_atom_indices,
            molecule_atom_lens = molecule_atom_lens,
            mask = mask,
-           atom_feats = atom_feats,
+           atom_feats = atom_feats.detach(),
            return_pae_logits = return_pae_logits
        )