Commit 56e3a67

final checkpointing code for template embedding pairformer stack

1 parent: 6f81b7a

3 files changed (+74, -10 lines)

alphafold3_pytorch/alphafold3.py (63 additions, 6 deletions)

```diff
@@ -174,9 +174,11 @@ def exclusive_cumsum(t, dim = -1):
 @typecheck
 def should_checkpoint(
     self: Module,
-    inputs: Tuple[Tensor, ...],
+    inputs: Tensor | Tuple[Tensor, ...],
     check_instance_variable: str | None = 'checkpoint'
 ) -> bool:
+    if torch.is_tensor(inputs):
+        inputs = (inputs,)
 
     return (
         self.training and
@@ -1481,6 +1483,8 @@ def __init__(
         pairformer_stack_depth = 2,
         pairwise_block_kwargs: dict = dict(),
         eps = 1e-5,
+        checkpoint = False,
+        checkpoint_segments = 1,
         layerscale_output = True
     ):
         super().__init__()
@@ -1504,6 +1508,9 @@ def __init__(
 
         self.pairformer_stack = layers
 
+        self.checkpoint = checkpoint
+        self.checkpoint_segments = checkpoint_segments
+
         self.final_norm = nn.LayerNorm(dim)
 
         # final projection of mean pooled repr -> out
@@ -1515,6 +1522,48 @@ def __init__(
 
         self.layerscale = nn.Parameter(torch.zeros(dim_pairwise)) if layerscale_output else 1.
 
+    @typecheck
+    def to_layers(
+        self,
+        v: Float['bt n n dt'],
+        *,
+        mask: Bool['bt n'] | None = None
+    ) -> Float['bt n n dt']:
+
+        for block in self.pairformer_stack:
+            v = block(
+                pairwise_repr = v,
+                mask = mask
+            ) + v
+
+        return v
+
+    @typecheck
+    def to_checkpointed_layers(
+        self,
+        v: Float['bt n n dt'],
+        *,
+        mask: Bool['bt n'] | None = None
+    ) -> Float['bt n n dt']:
+
+        wrapped_layers = []
+        inputs = (v, mask)
+
+        def block_wrapper(fn):
+            @wraps(fn)
+            def inner(inputs):
+                # add the residual here too, so the checkpointed path
+                # computes the same function as to_layers
+                v, mask = inputs
+                v = fn(pairwise_repr = v, mask = mask) + v
+                return v, mask
+            return inner
+
+        for block in self.pairformer_stack:
+            wrapped_layers.append(block_wrapper(block))
+
+        v, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False)
+
+        return v
+
     @typecheck
     def forward(
         self,
@@ -1539,11 +1588,19 @@ def forward(
         if exists(mask):
             mask = repeat(mask, 'b n -> (b t) n', t = num_templates)
 
-        for block in self.pairformer_stack:
-            v = block(
-                pairwise_repr = v,
-                mask = mask
-            ) + v
+        # going through the pairformer stack
+
+        if should_checkpoint(self, v):
+            to_layers_fn = self.to_checkpointed_layers
+        else:
+            to_layers_fn = self.to_layers
+
+        # layers
+        # todo - figure out why the single-letter names v and u are used here and rename them
+
+        v = to_layers_fn(v, mask = mask)
+
+        # final norm
 
         u = self.final_norm(v)
 
```
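The core of the change is the `to_checkpointed_layers` path: each pairformer block is wrapped so that a single `(tensor, mask)` tuple flows through it, which is the calling convention `torch.utils.checkpoint.checkpoint_sequential` expects. Below is a minimal, self-contained sketch of that pattern, with plain residual blocks standing in for the pairformer blocks; the `Block` class, dimensions, and segment count are illustrative assumptions, not code from the library.

```python
from functools import wraps

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

class Block(nn.Module):
    # stand-in for a pairformer block (illustrative assumption)
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU())

    def forward(self, x, mask = None):
        out = self.net(x)
        if mask is not None:
            out = out * mask.unsqueeze(-1)
        return out

def block_wrapper(fn):
    # adapt a block so checkpoint_sequential can thread one tuple through,
    # mirroring the wrapper in to_checkpointed_layers above
    @wraps(fn)
    def inner(inputs):
        x, mask = inputs
        return fn(x, mask = mask) + x, mask  # residual added outside the block
    return inner

blocks = [Block(64) for _ in range(4)]
wrapped = [block_wrapper(b) for b in blocks]

x = torch.randn(2, 16, 64, requires_grad = True)
mask = torch.ones(2, 16).bool()

# with 2 segments, only segment-boundary activations are kept in the forward
# pass; activations inside each segment are recomputed during backward
out, _ = checkpoint_sequential(wrapped, 2, (x, mask), use_reentrant = False)
out.sum().backward()
```

`use_reentrant = False` selects the non-reentrant checkpoint implementation, which handles inputs that do not require grad (such as the boolean mask here) and is the variant PyTorch recommends going forward.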

pyproject.toml (1 addition, 1 deletion)

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.65"
+version = "0.2.66"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
```

tests/test_af3.py (10 additions, 3 deletions)

```diff
@@ -380,15 +380,19 @@ def test_relative_position_encoding():
         additional_molecule_feats = additional_molecule_feats
     )
 
-def test_template_embed():
+@pytest.mark.parametrize('checkpoint', (False, True))
+def test_template_embed(
+    checkpoint
+):
     template_feats = torch.randn(2, 2, 16, 16, 77)
     template_mask = torch.ones((2, 2)).bool()
 
-    pairwise_repr = torch.randn(2, 16, 16, 128)
+    pairwise_repr = torch.randn(2, 16, 16, 128).requires_grad_()
     mask = torch.ones((2, 16)).bool()
 
     embedder = TemplateEmbedder(
-        dim_template_feats = 77
+        dim_template_feats = 77,
+        checkpoint = checkpoint
     )
 
     template_embed = embedder(
@@ -398,6 +402,9 @@ def test_template_embed():
         mask = mask
     )
 
+    if checkpoint:
+        loss = template_embed.sum()
+        loss.backward()
 
 def test_confidence_head():
     single_inputs_repr = torch.randn(2, 16, 77)
```
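
Taken together with the test, usage looks roughly like the sketch below. The constructor arguments and tensor shapes come from the diff; the keyword names in the embedder call (`templates`, `template_mask`, `pairwise_repr`) and the top-level import path are assumptions, since the diff truncates that call.

```python
import torch
from alphafold3_pytorch import TemplateEmbedder  # import path assumed

embedder = TemplateEmbedder(
    dim_template_feats = 77,
    checkpoint = True  # enable activation checkpointing (added in this commit)
)

embedder.train()  # should_checkpoint requires self.training to be True

template_feats = torch.randn(2, 2, 16, 16, 77)
template_mask = torch.ones((2, 2)).bool()
pairwise_repr = torch.randn(2, 16, 16, 128).requires_grad_()
mask = torch.ones((2, 16)).bool()

# keyword names below are assumed from the visible diff context
template_embed = embedder(
    templates = template_feats,
    template_mask = template_mask,
    pairwise_repr = pairwise_repr,
    mask = mask
)

# gradients flow back through the checkpointed pairformer stack
template_embed.sum().backward()
```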
