Skip to content

Commit 365377b

Browse files
committed
detach when recycling by default
1 parent b0324e5 commit 365377b

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4518,6 +4518,7 @@ def __init__(
45184518
checkpoint_input_embedding = False,
45194519
checkpoint_trunk_pairformer = False,
45204520
checkpoint_diffusion_token_transformer = False,
4521+
detach_when_recycling = True
45214522
):
45224523
super().__init__()
45234524

@@ -4642,6 +4643,8 @@ def __init__(
46424643

46434644
# recycling related
46444645

4646+
self.detach_when_recycling = detach_when_recycling
4647+
46454648
self.recycle_single = nn.Sequential(
46464649
nn.LayerNorm(dim_single),
46474650
LinearNoBias(dim_single, dim_single)
@@ -4835,7 +4838,8 @@ def forward(
48354838
return_present_sampled_atoms: bool = False,
48364839
return_confidence_head_logits: bool = False,
48374840
num_rollout_steps: int | None = None,
4838-
rollout_show_tqdm_pbar: bool = False
4841+
rollout_show_tqdm_pbar: bool = False,
4842+
detach_when_recycling: bool = None
48394843
) -> (
48404844
Float['b m 3'] |
48414845
Float['l 3'] |
@@ -4995,6 +4999,9 @@ def forward(
49954999

49965000
# init recycled single and pairwise
49975001

5002+
detach_when_recycling = default(detach_when_recycling, self.detach_when_recycling)
5003+
maybe_recycling_detach = torch.detach if detach_when_recycling else identity
5004+
49985005
recycled_pairwise = recycled_single = None
49995006
single = pairwise = None
50005007

@@ -5007,9 +5014,11 @@ def forward(
50075014
recycled_single = recycled_pairwise = 0.
50085015

50095016
if exists(single):
5017+
single = maybe_recycling_detach(single)
50105018
recycled_single = self.recycle_single(single)
50115019

50125020
if exists(pairwise):
5021+
pairwise = maybe_recycling_detach(pairwise)
50135022
recycled_pairwise = self.recycle_pairwise(pairwise)
50145023

50155024
single = single_init + recycled_single

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.106"
3+
version = "0.2.107"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)