Skip to content

Commit 7e3de12

Browse files
authored
Add support for activation checkpointing with DeepSpeed (#204)
* Update model_utils.py * Update alphafold3.py * Update model_utils.py * Update tensor_typing.py * Update tensor_typing.py * Update test_af3.py * Update alphafold3.py * Update .env.sample * Update test.yml * Update alphafold3.py * Update test_af3.py * Update pyproject.toml
1 parent 4e0fd07 commit 7e3de12

File tree

7 files changed

+62
-64
lines changed

7 files changed

+62
-64
lines changed

.env.sample

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
TYPECHECK=True
22
DEBUG=False
3+
DEEPSPEED_CHECKPOINTING=False

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ on: [push, pull_request]
44
env:
55
TYPECHECK: True
66
DEBUG: True
7+
DEEPSPEED_CHECKPOINTING: False
78

89
jobs:
910
build:

alphafold3_pytorch/alphafold3.py

Lines changed: 40 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from torch import Tensor
1414
from torch.amp import autocast
1515
import torch.nn.functional as F
16-
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
1716

1817
from torch.nn import (
1918
Module,
@@ -73,6 +72,7 @@
7372
ExpressCoordinatesInFrame,
7473
RigidFrom3Points,
7574
calculate_weighted_rigid_align_weights,
75+
package_available,
7676
)
7777

7878
from frame_averaging_pytorch import FrameAverage
@@ -84,6 +84,7 @@
8484
import einx
8585
from einops import rearrange, repeat, reduce, einsum, pack, unpack
8686
from einops.layers.torch import Rearrange
87+
from environs import Env
8788

8889
from tqdm import tqdm
8990

@@ -169,10 +170,23 @@
169170

170171
LinearNoBias = partial(Linear, bias = False)
171172

173+
# environment
174+
175+
env = Env()
176+
env.read_env()
177+
172178
# always use non reentrant checkpointing
173179

174-
checkpoint = partial(checkpoint, use_reentrant = False)
175-
checkpoint_sequential = partial(checkpoint_sequential, use_reentrant = False)
180+
DEEPSPEED_CHECKPOINTING = env.bool('DEEPSPEED_CHECKPOINTING', False)
181+
182+
if DEEPSPEED_CHECKPOINTING:
183+
assert package_available("deepspeed"), "DeepSpeed must be installed for checkpointing."
184+
185+
import deepspeed
186+
187+
checkpoint = deepspeed.checkpointing.checkpoint
188+
else:
189+
checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant = False)
176190

177191
# helper functions
178192

@@ -1061,7 +1075,6 @@ def __init__(
10611075
msa_pwa_heads = 8,
10621076
msa_pwa_dim_head = 32,
10631077
checkpoint = False,
1064-
checkpoint_segments = 1,
10651078
pairwise_block_kwargs: dict = dict(),
10661079
max_num_msa: int | None = None,
10671080
layerscale_output: bool = True
@@ -1112,7 +1125,6 @@ def __init__(
11121125
]))
11131126

11141127
self.checkpoint = checkpoint
1115-
self.checkpoint_segments = checkpoint_segments
11161128

11171129
self.layers = layers
11181130

@@ -1182,19 +1194,19 @@ def inner(inputs):
11821194
return pairwise_repr, mask, msa, msa_mask
11831195
return inner
11841196

1185-
def pairwise_block_wrapper(fn):
1197+
def msa_transition_wrapper(fn):
11861198
@wraps(fn)
11871199
def inner(inputs):
11881200
pairwise_repr, mask, msa, msa_mask = inputs
1189-
pairwise_repr = fn(pairwise_repr = pairwise_repr, mask = mask)
1201+
msa = fn(msa) + msa
11901202
return pairwise_repr, mask, msa, msa_mask
11911203
return inner
11921204

1193-
def msa_transition_wrapper(fn):
1205+
def pairwise_block_wrapper(fn):
11941206
@wraps(fn)
11951207
def inner(inputs):
11961208
pairwise_repr, mask, msa, msa_mask = inputs
1197-
msa = fn(msa) + msa
1209+
pairwise_repr = fn(pairwise_repr = pairwise_repr, mask = mask)
11981210
return pairwise_repr, mask, msa, msa_mask
11991211
return inner
12001212

@@ -1210,8 +1222,10 @@ def inner(inputs):
12101222
wrapped_layers.append(msa_transition_wrapper(msa_transition))
12111223
wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
12121224

1213-
pairwise_repr, *_ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
1225+
for layer in wrapped_layers:
1226+
inputs = checkpoint(layer, inputs)
12141227

1228+
pairwise_repr, *_ = inputs
12151229
return pairwise_repr
12161230

12171231
@typecheck
@@ -1318,7 +1332,6 @@ def __init__(
13181332
dropout_row_prob = 0.25,
13191333
num_register_tokens = 0,
13201334
checkpoint = False,
1321-
checkpoint_segments = 1,
13221335
pairwise_block_kwargs: dict = dict(),
13231336
pair_bias_attn_kwargs: dict = dict()
13241337
):
@@ -1357,7 +1370,6 @@ def __init__(
13571370
# checkpointing
13581371

13591372
self.checkpoint = checkpoint
1360-
self.checkpoint_segments = checkpoint_segments
13611373

13621374
# https://arxiv.org/abs/2405.16039 and https://arxiv.org/abs/2405.15071
13631375
# although possibly recycling already takes care of this
@@ -1446,8 +1458,10 @@ def inner(inputs, *args, **kwargs):
14461458
wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
14471459
wrapped_layers.append(single_transition_wrapper(single_transition))
14481460

1449-
single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
1461+
for layer in wrapped_layers:
1462+
inputs = checkpoint(layer, inputs)
14501463

1464+
single_repr, pairwise_repr, _ = inputs
14511465
return single_repr, pairwise_repr
14521466

14531467
@typecheck
@@ -1590,7 +1604,6 @@ def __init__(
15901604
pairwise_block_kwargs: dict = dict(),
15911605
eps = 1e-5,
15921606
checkpoint = False,
1593-
checkpoint_segments = 1,
15941607
layerscale_output = True
15951608
):
15961609
super().__init__()
@@ -1615,7 +1628,6 @@ def __init__(
16151628
self.pairformer_stack = layers
16161629

16171630
self.checkpoint = checkpoint
1618-
self.checkpoint_segments = checkpoint_segments
16191631

16201632
self.final_norm = nn.LayerNorm(dim)
16211633

@@ -1666,8 +1678,10 @@ def inner(inputs):
16661678
for block in self.pairformer_stack:
16671679
wrapped_layers.append(block_wrapper(block))
16681680

1669-
templates, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
1681+
for layer in wrapped_layers:
1682+
inputs = checkpoint(layer, inputs)
16701683

1684+
templates, _ = inputs
16711685
return templates
16721686

16731687
@typecheck
@@ -1877,7 +1891,6 @@ def __init__(
18771891
add_residual = True,
18781892
use_linear_attn = False,
18791893
checkpoint = False,
1880-
checkpoint_segments = 1,
18811894
linear_attn_kwargs = dict(
18821895
heads = 8,
18831896
dim_head = 16
@@ -1956,7 +1969,6 @@ def __init__(
19561969
assert not (not serial and checkpoint), 'checkpointing can only be used for serial version of diffusion transformer'
19571970

19581971
self.checkpoint = checkpoint
1959-
self.checkpoint_segments = checkpoint_segments
19601972

19611973
self.layers = layers
19621974

@@ -2021,9 +2033,10 @@ def inner(inputs):
20212033
wrapped_layers.append(attn_wrapper(attn))
20222034
wrapped_layers.append(transition_wrapper(transition))
20232035

2024-
out = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
2036+
for layer in wrapped_layers:
2037+
inputs = checkpoint(layer, inputs)
20252038

2026-
noised_repr, *_ = out
2039+
noised_repr, *_ = inputs
20272040
return noised_repr
20282041

20292042
@typecheck
@@ -2314,10 +2327,6 @@ def __init__(
23142327

23152328
self.attended_token_norm = nn.LayerNorm(dim_token)
23162329

2317-
# checkpointing
2318-
2319-
self.checkpoint_token_transformer = checkpoint_token_transformer
2320-
23212330
# atom attention decoding related modules
23222331

23232332
self.tokens_to_atom_decoder_input_cond = LinearNoBias(dim_token, dim_atom)
@@ -2332,6 +2341,7 @@ def __init__(
23322341
serial = serial,
23332342
use_linear_attn = use_linear_attn,
23342343
linear_attn_kwargs = linear_attn_kwargs,
2344+
checkpoint = checkpoint_token_transformer,
23352345
**atom_decoder_kwargs
23362346
)
23372347

@@ -2484,18 +2494,11 @@ def forward(
24842494
molecule_atom_lens = molecule_atom_lens
24852495
)
24862496

2487-
# maybe checkpoint token transformer
2488-
2489-
token_transformer = self.token_transformer
2490-
2491-
if should_checkpoint(self, tokens, 'checkpoint_token_transformer'):
2492-
token_transformer = partial(checkpoint, token_transformer)
2493-
24942497
# token transformer
24952498

24962499
tokens = self.cond_tokens_with_cond_single(conditioned_single_repr) + tokens
24972500

2498-
tokens = token_transformer(
2501+
tokens = self.token_transformer(
24992502
tokens,
25002503
mask = mask,
25012504
single_repr = conditioned_single_repr,
@@ -5991,6 +5994,7 @@ def __init__(
59915994
dim_template_feats = dim_template_feats,
59925995
dim = dim_template_model,
59935996
dim_pairwise = dim_pairwise,
5997+
checkpoint=checkpoint_input_embedding,
59945998
**template_embedder_kwargs
59955999
)
59966000

@@ -6003,6 +6007,7 @@ def __init__(
60036007
dim_pairwise = dim_pairwise,
60046008
dim_msa_input = dim_msa_inputs,
60056009
dim_additional_msa_feats = dim_additional_msa_feats,
6010+
checkpoint=checkpoint_input_embedding,
60066011
**msa_module_kwargs,
60076012
)
60086013

@@ -6011,6 +6016,7 @@ def __init__(
60116016
self.pairformer = PairformerStack(
60126017
dim_single = dim_single,
60136018
dim_pairwise = dim_pairwise,
6019+
checkpoint=checkpoint_trunk_pairformer,
60146020
**pairformer_stack
60156021
)
60166022

@@ -6115,13 +6121,6 @@ def __init__(
61156121

61166122
self.register_buffer('lddt_thresholds', torch.tensor([0.5, 1.0, 2.0, 4.0]))
61176123

6118-
# checkpointing related
6119-
6120-
self.checkpoint_trunk_pairformer = checkpoint_trunk_pairformer
6121-
self.checkpoint_diffusion_token_transformer = checkpoint_diffusion_token_transformer
6122-
self.checkpoint_distogram_head = checkpoint_distogram_head
6123-
self.checkpoint_confidence_head = checkpoint_confidence_head
6124-
61256124
# loss related
61266125

61276126
self.ignore_index = ignore_index
@@ -6510,16 +6509,9 @@ def forward(
65106509

65116510
pairwise = embedded_msa + pairwise
65126511

6513-
# maybe checkpoint trunk pairformer
6514-
6515-
pairformer = self.pairformer
6516-
6517-
if should_checkpoint(self, (single, pairwise), 'checkpoint_trunk_pairformer'):
6518-
pairformer = partial(checkpoint, pairformer)
6519-
65206512
# main attention trunk (pairformer)
65216513

6522-
single, pairwise = pairformer(
6514+
single, pairwise = self.pairformer(
65236515
single_repr = single,
65246516
pairwise_repr = pairwise,
65256517
mask = mask
@@ -6650,12 +6642,7 @@ def forward(
66506642

66516643
distance_labels = torch.where(distogram_mask, distance_labels, ignore)
66526644

6653-
distogram_head_fn = self.distogram_head
6654-
6655-
if should_checkpoint(self, pairwise, 'checkpoint_distogram_head'):
6656-
distogram_head_fn = partial(checkpoint, distogram_head_fn)
6657-
6658-
distogram_logits = distogram_head_fn(
6645+
distogram_logits = self.distogram_head(
66596646
pairwise,
66606647
molecule_atom_lens = molecule_atom_lens,
66616648
atom_feats = atom_feats

alphafold3_pytorch/tensor_typing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,13 @@ def __getitem__(self, shapes: str):
6868

6969
if should_typecheck:
7070
logger.info("Type checking is enabled.")
71+
else:
72+
logger.info("Type checking is disabled.")
73+
7174
if IS_DEBUGGING:
7275
logger.info("Debugging is enabled.")
76+
else:
77+
logger.info("Debugging is disabled.")
7378

7479
__all__ = [
7580
Shaped,

alphafold3_pytorch/utils/model_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Callable, List, Tuple, Union
33

44
import einx
5+
import pkg_resources
56
import torch
67
import torch.nn.functional as F
78
from einops import einsum, pack, rearrange, reduce, repeat, unpack
@@ -621,6 +622,19 @@ def should_checkpoint(
621622
)
622623

623624

625+
@typecheck
626+
def package_available(package_name: str) -> bool:
627+
"""Check if a package is available in your environment.
628+
629+
:param package_name: The name of the package to be checked.
630+
:return: `True` if the package is available. `False` otherwise.
631+
"""
632+
try:
633+
return pkg_resources.require(package_name) is not None
634+
except pkg_resources.DistributionNotFound:
635+
return False
636+
637+
624638
# functions for deriving the frames for ligands
625639
# this follows the logic from Alphafold3 Supplementary section 4.3.2
626640

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ dependencies = [
4141
"huggingface_hub>=0.21.4",
4242
"jaxtyping>=0.2.28",
4343
"lightning>=2.2.5",
44-
"numpy",
44+
"numpy==1.23.5",
4545
"polars>=1.1.0",
4646
"pdbeccdutils>=0.8.5",
4747
"pydantic>=2.8.2",

tests/test_af3.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -818,22 +818,12 @@ def test_alphafold3_without_msa_and_templates():
818818
depth = 1
819819
),
820820
pairformer_stack = dict(
821-
checkpoint = True,
822821
depth = 2
823822
),
824823
diffusion_module_kwargs = dict(
825824
atom_encoder_depth = 2,
826-
atom_encoder_kwargs = dict(
827-
checkpoint = True,
828-
),
829825
token_transformer_depth = 2,
830-
token_transformer_kwargs = dict(
831-
checkpoint = True,
832-
),
833826
atom_decoder_depth = 2,
834-
atom_decoder_kwargs = dict(
835-
checkpoint = True,
836-
),
837827
),
838828
)
839829

0 commit comments

Comments
 (0)