@@ -2234,7 +2234,7 @@ def __init__(
22342234 atom_decoder_kwargs : dict = dict (),
22352235 token_transformer_kwargs : dict = dict (),
22362236 use_linear_attn = False ,
2237- checkpoint_token_transformer = False ,
2237+ checkpoint = False ,
22382238 linear_attn_kwargs : dict = dict (
22392239 heads = 8 ,
22402240 dim_head = 16
@@ -2300,6 +2300,7 @@ def __init__(
23002300 serial = serial ,
23012301 use_linear_attn = use_linear_attn ,
23022302 linear_attn_kwargs = linear_attn_kwargs ,
2303+ checkpoint = checkpoint ,
23032304 ** atom_encoder_kwargs
23042305 )
23052306
@@ -2322,6 +2323,7 @@ def __init__(
23222323 depth = token_transformer_depth ,
23232324 heads = token_transformer_heads ,
23242325 serial = serial ,
2326+ checkpoint = checkpoint ,
23252327 ** token_transformer_kwargs
23262328 )
23272329
@@ -2341,7 +2343,7 @@ def __init__(
23412343 serial = serial ,
23422344 use_linear_attn = use_linear_attn ,
23432345 linear_attn_kwargs = linear_attn_kwargs ,
2344- checkpoint = checkpoint_token_transformer ,
2346+ checkpoint = checkpoint ,
23452347 ** atom_decoder_kwargs
23462348 )
23472349
@@ -4228,7 +4230,8 @@ def __init__(
42284230 dim_pairwise = 128 ,
42294231 num_dist_bins = 38 ,
42304232 dim_atom = 128 ,
4231- atom_resolution = False
4233+ atom_resolution = False ,
4234+ checkpoint = False ,
42324235 ):
42334236 super ().__init__ ()
42344237
@@ -4245,29 +4248,120 @@ def __init__(
42454248 if atom_resolution :
42464249 self .atom_feats_to_pairwise = LinearNoBiasThenOuterSum (dim_atom , dim_pairwise )
42474250
4251+ # checkpointing
4252+
4253+ self .checkpoint = checkpoint
4254+
42484255 # tensor typing
42494256
42504257 self .da = dim_atom
42514258
@typecheck
def to_layers(
    self,
    pairwise_repr: Float["b n n d"],  # type: ignore
    molecule_atom_lens: Int["b n"] | None = None,  # type: ignore
    atom_feats: Float["b m {self.da}"] | None = None,  # type: ignore
) -> Float["b l n n"] | Float["b l m m"]:  # type: ignore
    """Produce distogram logits without activation checkpointing.

    :param pairwise_repr: The pairwise representation tensor.
    :param molecule_atom_lens: The molecule atom lengths tensor; must be
        given when ``self.atom_resolution`` is set.
    :param atom_feats: The atom features tensor; must be given when
        ``self.atom_resolution`` is set.
    :return: The distogram logits.
    """
    # token-resolution path: project the symmetrized pairwise repr directly
    if not self.atom_resolution:
        return self.to_distogram_logits(symmetrize(pairwise_repr))

    # atom-resolution path needs both optional inputs
    assert exists(molecule_atom_lens)
    assert exists(atom_feats)

    # expand the pairwise repr using the per-molecule atom lengths, then add
    # the pairwise term derived from the atom features
    expanded_repr = batch_repeat_interleave_pairwise(pairwise_repr, molecule_atom_lens)
    atom_pairwise = self.atom_feats_to_pairwise(atom_feats)

    return self.to_distogram_logits(symmetrize(expanded_repr + atom_pairwise))
42704284
@typecheck
def to_checkpointed_layers(
    self,
    pairwise_repr: Float["b n n d"],  # type: ignore
    molecule_atom_lens: Int["b n"] | None = None,  # type: ignore
    atom_feats: Float["b m {self.da}"] | None = None,  # type: ignore
) -> Float["b l n n"] | Float["b l m m"]:  # type: ignore
    """Compute the checkpointed distogram logits.

    Each stage is wrapped so that it takes and returns the same 3-tuple
    ``(pairwise_repr, molecule_atom_lens, atom_feats)``, which lets the stages
    be chained one after another through ``checkpoint``.

    :param pairwise_repr: The pairwise representation tensor.
    :param molecule_atom_lens: The molecule atom lengths tensor; must be
        given when ``self.atom_resolution`` is set.
    :param atom_feats: The atom features tensor; must be given when
        ``self.atom_resolution`` is set.
    :return: The checkpointed distogram logits.
    """
    wrapped_layers = []
    inputs = (pairwise_repr, molecule_atom_lens, atom_feats)

    def atom_resolution_wrapper(fn):
        # stage 1 (atom resolution only): expand the pairwise repr and add
        # the pairwise term computed by `fn` from the atom features
        @wraps(fn)
        def inner(inputs):
            pairwise_repr, molecule_atom_lens, atom_feats = inputs

            assert exists(molecule_atom_lens)
            assert exists(atom_feats)

            pairwise_repr = batch_repeat_interleave_pairwise(pairwise_repr, molecule_atom_lens)

            pairwise_repr = pairwise_repr + fn(atom_feats)
            return pairwise_repr, molecule_atom_lens, atom_feats

        return inner

    def distogram_wrapper(fn):
        # final stage: symmetrize and project to logits; the logits take the
        # first slot of the tuple so the tuple shape stays uniform
        @wraps(fn)
        def inner(inputs):
            pairwise_repr, molecule_atom_lens, atom_feats = inputs
            pairwise_repr = fn(symmetrize(pairwise_repr))
            return pairwise_repr, molecule_atom_lens, atom_feats

        return inner

    if self.atom_resolution:
        wrapped_layers.append(atom_resolution_wrapper(self.atom_feats_to_pairwise))
    wrapped_layers.append(distogram_wrapper(self.to_distogram_logits))

    # NOTE(review): `checkpoint` here is the module-level callable (presumably
    # a torch.utils.checkpoint wrapper), not the boolean `self.checkpoint`
    for layer in wrapped_layers:
        inputs = checkpoint(layer, inputs)

    # every stage returns a 3-tuple, so a two-name unpack would raise
    # ValueError; keep the first element and discard the rest
    logits, *_ = inputs
    return logits
4335+
@typecheck
def forward(
    self,
    pairwise_repr: Float["b n n d"],  # type: ignore
    molecule_atom_lens: Int["b n"] | None = None,  # type: ignore
    atom_feats: Float["b m {self.da}"] | None = None,  # type: ignore
) -> Float["b l n n"] | Float["b l m m"]:  # type: ignore
    """Compute the distogram logits, with or without checkpointing.

    Delegates to ``to_checkpointed_layers`` when ``should_checkpoint`` says so
    for this module and input, and to ``to_layers`` otherwise.

    :param pairwise_repr: The pairwise representation tensor.
    :param molecule_atom_lens: The molecule atom lengths tensor.
    :param atom_feats: The atom features tensor.
    :return: The distogram logits.
    """
    # pick the checkpointed or the plain implementation
    use_checkpointing = should_checkpoint(self, pairwise_repr)
    compute_logits = self.to_checkpointed_layers if use_checkpointing else self.to_layers

    return compute_logits(
        pairwise_repr=pairwise_repr,
        molecule_atom_lens=molecule_atom_lens,
        atom_feats=atom_feats,
    )
4364+
42714365# confidence head
42724366
42734367class ConfidenceHeadLogits (NamedTuple ):
@@ -5892,7 +5986,7 @@ def __init__(
58925986 checkpoint_trunk_pairformer = False ,
58935987 checkpoint_distogram_head = False ,
58945988 checkpoint_confidence_head = False ,
5895- checkpoint_diffusion_token_transformer = False ,
5989+ checkpoint_diffusion_module = False ,
58965990 detach_when_recycling = True ,
58975991 pdb_training_set = True ,
58985992 ):
@@ -6048,7 +6142,7 @@ def __init__(
60486142 dim_atompair = dim_atompair ,
60496143 dim_token = dim_token ,
60506144 dim_single = dim_single + dim_single_inputs ,
6051- checkpoint_token_transformer = checkpoint_diffusion_token_transformer ,
6145+ checkpoint = checkpoint_diffusion_module ,
60526146 ** diffusion_module_kwargs
60536147 )
60546148
@@ -6081,6 +6175,7 @@ def __init__(
60816175 dim_atom = dim_atom ,
60826176 num_dist_bins = num_dist_bins ,
60836177 atom_resolution = distogram_atom_resolution ,
6178+ checkpoint = checkpoint_distogram_head
60846179 )
60856180
60866181 # lddt related
0 commit comments