Skip to content

Commit 8b2bf10

Browse files
authored
Enable DistogramHead checkpointing (#206)
* Update test_af3.py * Update alphafold3.py * Update alphafold3.py
1 parent 496790f commit 8b2bf10

File tree

2 files changed

+106
-11
lines changed

2 files changed

+106
-11
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 105 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2234,7 +2234,7 @@ def __init__(
22342234
atom_decoder_kwargs: dict = dict(),
22352235
token_transformer_kwargs: dict = dict(),
22362236
use_linear_attn = False,
2237-
checkpoint_token_transformer = False,
2237+
checkpoint = False,
22382238
linear_attn_kwargs: dict = dict(
22392239
heads = 8,
22402240
dim_head = 16
@@ -2300,6 +2300,7 @@ def __init__(
23002300
serial = serial,
23012301
use_linear_attn = use_linear_attn,
23022302
linear_attn_kwargs = linear_attn_kwargs,
2303+
checkpoint = checkpoint,
23032304
**atom_encoder_kwargs
23042305
)
23052306

@@ -2322,6 +2323,7 @@ def __init__(
23222323
depth = token_transformer_depth,
23232324
heads = token_transformer_heads,
23242325
serial = serial,
2326+
checkpoint = checkpoint,
23252327
**token_transformer_kwargs
23262328
)
23272329

@@ -2341,7 +2343,7 @@ def __init__(
23412343
serial = serial,
23422344
use_linear_attn = use_linear_attn,
23432345
linear_attn_kwargs = linear_attn_kwargs,
2344-
checkpoint = checkpoint_token_transformer,
2346+
checkpoint = checkpoint,
23452347
**atom_decoder_kwargs
23462348
)
23472349

@@ -4228,7 +4230,8 @@ def __init__(
42284230
dim_pairwise = 128,
42294231
num_dist_bins = 38,
42304232
dim_atom = 128,
4231-
atom_resolution = False
4233+
atom_resolution = False,
4234+
checkpoint = False,
42324235
):
42334236
super().__init__()
42344237

@@ -4245,29 +4248,120 @@ def __init__(
42454248
if atom_resolution:
42464249
self.atom_feats_to_pairwise = LinearNoBiasThenOuterSum(dim_atom, dim_pairwise)
42474250

4251+
# checkpointing
4252+
4253+
self.checkpoint = checkpoint
4254+
42484255
# tensor typing
42494256

42504257
self.da = dim_atom
42514258

42524259
@typecheck
def to_layers(
    self,
    pairwise_repr: Float["b n n d"],  # type: ignore
    molecule_atom_lens: Int["b n"] | None = None,  # type: ignore
    atom_feats: Float["b m {self.da}"] | None = None,  # type: ignore
) -> Float["b l n n"] | Float["b l m m"]:  # type: ignore
    """Compute the distogram logits without checkpointing.

    :param pairwise_repr: The pairwise representation tensor.
    :param molecule_atom_lens: The molecule atom lengths tensor. Must be
        given when ``self.atom_resolution`` is set.
    :param atom_feats: The atom features tensor. Must be given when
        ``self.atom_resolution`` is set.
    :return: The distogram logits.
    """
    if self.atom_resolution:
        assert exists(molecule_atom_lens)
        assert exists(atom_feats)

        # expand the pairwise representation with the per-molecule atom
        # lengths, then add the pairwise term derived from the atom features
        expanded = batch_repeat_interleave_pairwise(pairwise_repr, molecule_atom_lens)
        pairwise_repr = expanded + self.atom_feats_to_pairwise(atom_feats)

    # the projection to logits is applied to the symmetrized representation
    return self.to_distogram_logits(symmetrize(pairwise_repr))
42704284

4285+
@typecheck
def to_checkpointed_layers(
    self,
    pairwise_repr: Float["b n n d"],  # type: ignore
    molecule_atom_lens: Int["b n"] | None = None,  # type: ignore
    atom_feats: Float["b m {self.da}"] | None = None,  # type: ignore
) -> Float["b l n n"] | Float["b l m m"]:  # type: ignore
    """Compute the distogram logits, running each stage through ``checkpoint``.

    Performs the same steps as ``to_layers``, but each stage is wrapped in a
    function that takes and returns the full
    ``(pairwise_repr, molecule_atom_lens, atom_feats)`` tuple so the stages
    can be chained through ``checkpoint`` one after another.

    :param pairwise_repr: The pairwise representation tensor.
    :param molecule_atom_lens: The molecule atom lengths tensor. Must be
        given when ``self.atom_resolution`` is set.
    :param atom_feats: The atom features tensor. Must be given when
        ``self.atom_resolution`` is set.
    :return: The checkpointed distogram logits.
    """

    def atom_resolution_wrapper(fn):
        # stage 1 (atom resolution only): expand the pairwise representation
        # and add the pairwise term that `fn` derives from the atom features
        @wraps(fn)
        def inner(inputs):
            pairwise_repr, molecule_atom_lens, atom_feats = inputs

            assert exists(molecule_atom_lens)
            assert exists(atom_feats)

            pairwise_repr = batch_repeat_interleave_pairwise(pairwise_repr, molecule_atom_lens)

            pairwise_repr = pairwise_repr + fn(atom_feats)
            return pairwise_repr, molecule_atom_lens, atom_feats

        return inner

    def distogram_wrapper(fn):
        # final stage: project the symmetrized representation to logits; the
        # logits take the first slot of the returned tuple
        @wraps(fn)
        def inner(inputs):
            pairwise_repr, molecule_atom_lens, atom_feats = inputs
            pairwise_repr = fn(symmetrize(pairwise_repr))
            return pairwise_repr, molecule_atom_lens, atom_feats

        return inner

    wrapped_layers = []

    if self.atom_resolution:
        wrapped_layers.append(atom_resolution_wrapper(self.atom_feats_to_pairwise))

    wrapped_layers.append(distogram_wrapper(self.to_distogram_logits))

    inputs = (pairwise_repr, molecule_atom_lens, atom_feats)

    for layer in wrapped_layers:
        inputs = checkpoint(layer, inputs)

    # every stage returns a 3-tuple, so a two-name unpack (`logits, _`) would
    # raise ValueError; keep the first element and discard the rest
    logits, *_ = inputs
    return logits
4335+
4336+
@typecheck
def forward(
    self,
    pairwise_repr: Float["b n n d"],  # type: ignore
    molecule_atom_lens: Int["b n"] | None = None,  # type: ignore
    atom_feats: Float["b m {self.da}"] | None = None,  # type: ignore
) -> Float["b l n n"] | Float["b l m m"]:  # type: ignore
    """Compute the distogram logits, checkpointed or not.

    Delegates to ``to_checkpointed_layers`` when
    ``should_checkpoint(self, pairwise_repr)`` is truthy and to ``to_layers``
    otherwise; both receive the same keyword arguments.

    :param pairwise_repr: The pairwise representation tensor.
    :param molecule_atom_lens: The molecule atom lengths tensor.
    :param atom_feats: The atom features tensor.
    :return: The distogram logits.
    """
    # choose the execution path for the layers
    run_layers = (
        self.to_checkpointed_layers
        if should_checkpoint(self, pairwise_repr)
        else self.to_layers
    )

    return run_layers(
        pairwise_repr=pairwise_repr,
        molecule_atom_lens=molecule_atom_lens,
        atom_feats=atom_feats,
    )
4364+
42714365
# confidence head
42724366

42734367
class ConfidenceHeadLogits(NamedTuple):
@@ -5892,7 +5986,7 @@ def __init__(
58925986
checkpoint_trunk_pairformer = False,
58935987
checkpoint_distogram_head = False,
58945988
checkpoint_confidence_head = False,
5895-
checkpoint_diffusion_token_transformer = False,
5989+
checkpoint_diffusion_module = False,
58965990
detach_when_recycling = True,
58975991
pdb_training_set=True,
58985992
):
@@ -6048,7 +6142,7 @@ def __init__(
60486142
dim_atompair = dim_atompair,
60496143
dim_token = dim_token,
60506144
dim_single = dim_single + dim_single_inputs,
6051-
checkpoint_token_transformer = checkpoint_diffusion_token_transformer,
6145+
checkpoint = checkpoint_diffusion_module,
60526146
**diffusion_module_kwargs
60536147
)
60546148

@@ -6081,6 +6175,7 @@ def __init__(
60816175
dim_atom = dim_atom,
60826176
num_dist_bins = num_dist_bins,
60836177
atom_resolution = distogram_atom_resolution,
6178+
checkpoint = checkpoint_distogram_head
60846179
)
60856180

60866181
# lddt related

tests/test_af3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ def test_alphafold3_without_msa_and_templates():
807807
num_dist_bins = 38,
808808
num_molecule_mods = 0,
809809
checkpoint_trunk_pairformer = True,
810-
checkpoint_diffusion_token_transformer = True,
810+
checkpoint_diffusion_module = True,
811811
confidence_head_kwargs = dict(
812812
pairformer_depth = 1
813813
),

0 commit comments

Comments
 (0)