Skip to content

Commit ac669e8

Browse files
committed
throw in recurrent depth for pairformer stack
1 parent 535fad9 commit ac669e8

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,7 @@ def __init__(
10211021
dim_single = 384,
10221022
dim_pairwise = 128,
10231023
depth = 48,
1024+
recurrent_depth = 1, # effective depth will be depth * recurrent_depth
10241025
pair_bias_attn_dim_head = 64,
10251026
pair_bias_attn_heads = 16,
10261027
dropout_row_prob = 0.25,
@@ -1058,6 +1059,12 @@ def __init__(
10581059

10591060
self.layers = layers
10601061

1062+
# https://arxiv.org/abs/2405.16039 and https://arxiv.org/abs/2405.15071
1063+
# although possibly recycling already takes care of this
1064+
1065+
assert recurrent_depth > 0
1066+
self.recurrent_depth = recurrent_depth
1067+
10611068
self.num_registers = num_register_tokens
10621069
self.has_registers = num_register_tokens > 0
10631070

@@ -1093,16 +1100,17 @@ def forward(
10931100

10941101
# main transformer block layers
10951102

1096-
for (
1097-
pairwise_block,
1098-
pair_bias_attn,
1099-
single_transition
1100-
) in self.layers:
1103+
for _ in range(self.recurrent_depth):
1104+
for (
1105+
pairwise_block,
1106+
pair_bias_attn,
1107+
single_transition
1108+
) in self.layers:
11011109

1102-
pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
1110+
pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
11031111

1104-
single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
1105-
single_repr = single_transition(single_repr) + single_repr
1112+
single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
1113+
single_repr = single_transition(single_repr) + single_repr
11061114

11071115
# splice out registers
11081116

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.10"
3+
version = "0.1.11"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,18 @@ def test_centre_random_augmentation():
156156
assert augmented_coords.shape == coords.shape
157157

158158

159-
def test_pairformer():
159+
@pytest.mark.parametrize('recurrent_depth', (1, 2))
160+
def test_pairformer(
161+
recurrent_depth
162+
):
160163
single = torch.randn(2, 16, 384)
161164
pairwise = torch.randn(2, 16, 16, 128)
162165
mask = torch.randint(0, 2, (2, 16)).bool()
163166

164167
pairformer = PairformerStack(
165168
depth = 4,
166169
num_register_tokens = 4,
170+
recurrent_depth = recurrent_depth
167171
)
168172

169173
single_out, pairwise_out = pairformer(

0 commit comments

Comments
 (0)