Skip to content

Commit 1684d50

Browse files
committed
complete value residual for pairformer stack
1 parent 49f7c96 commit 1684d50

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -907,8 +907,12 @@ def forward(
907907
self,
908908
pairwise_repr: Float['b n n d'],
909909
mask: Bool['b n'] | None = None,
910+
return_values = False,
910911
**kwargs
911-
) -> Float['b n n d']:
912+
) -> (
913+
Float['b n n d'] |
914+
tuple[Float['b n n d'], Tensor]
915+
):
912916

913917
if self.need_transpose:
914918
pairwise_repr = rearrange(pairwise_repr, 'b i j d -> b j i d')
@@ -923,10 +927,11 @@ def forward(
923927

924928
pairwise_repr, unpack_one = pack_one(pairwise_repr, '* n d')
925929

926-
out = self.attn(
930+
out, values = self.attn(
927931
pairwise_repr,
928932
mask = mask,
929933
attn_bias = attn_bias,
934+
return_values = True,
930935
**kwargs
931936
)
932937

@@ -935,7 +940,12 @@ def forward(
935940
if self.need_transpose:
936941
out = rearrange(out, 'b j i d -> b i j d')
937942

938-
return self.dropout(out)
943+
out = self.dropout(out)
944+
945+
if not return_values:
946+
return out
947+
948+
return out, values
939949

940950
# PairwiseBlock
941951
# used in both MSAModule and Pairformer
@@ -978,15 +988,27 @@ def forward(
978988
self,
979989
*,
980990
pairwise_repr: Float['b n n d'],
981-
mask: Bool['b n'] | None = None
991+
mask: Bool['b n'] | None = None,
992+
value_residuals: tuple[Tensor, Tensor] | None = None,
993+
return_values = False,
982994
):
983995
pairwise_repr = self.tri_mult_outgoing(pairwise_repr, mask = mask) + pairwise_repr
984996
pairwise_repr = self.tri_mult_incoming(pairwise_repr, mask = mask) + pairwise_repr
985-
pairwise_repr = self.tri_attn_starting(pairwise_repr, mask = mask) + pairwise_repr
986-
pairwise_repr = self.tri_attn_ending(pairwise_repr, mask = mask) + pairwise_repr
997+
998+
attn_start_value_residual, attn_end_value_residual = default(value_residuals, (None, None))
999+
1000+
attn_start_out, attn_start_values = self.tri_attn_starting(pairwise_repr, mask = mask, value_residual = attn_start_value_residual, return_values = True)
1001+
pairwise_repr = attn_start_out + pairwise_repr
1002+
1003+
attn_end_out, attn_end_values = self.tri_attn_ending(pairwise_repr, mask = mask, value_residual = attn_end_value_residual, return_values = True)
1004+
pairwise_repr = attn_end_out + pairwise_repr
9871005

9881006
pairwise_repr = self.pairwise_transition(pairwise_repr) + pairwise_repr
989-
return pairwise_repr
1007+
1008+
if not return_values:
1009+
return pairwise_repr
1010+
1011+
return pairwise_repr, (attn_start_values, attn_end_values)
9901012

9911013
# msa module
9921014

@@ -1465,6 +1487,7 @@ def to_layers(
14651487
) -> Tuple[Float['b n ds'], Float['b n n dp']]:
14661488

14671489
value_residual = None
1490+
pairwise_value_residuals = None
14681491

14691492
for _ in range(self.recurrent_depth):
14701493
for (
@@ -1473,14 +1496,15 @@ def to_layers(
14731496
single_transition
14741497
) in self.layers:
14751498

1476-
pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)
1499+
pairwise_repr, pairwise_attn_values = pairwise_block(pairwise_repr = pairwise_repr, mask = mask, value_residuals = pairwise_value_residuals, return_values = True)
14771500

14781501
attn_out, attn_values = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = value_residual)
14791502

14801503
single_repr = single_repr + attn_out
14811504

14821505
if self.add_value_residual:
14831506
value_residual = default(value_residual, attn_values)
1507+
pairwise_value_residuals = default(pairwise_value_residuals, pairwise_attn_values)
14841508

14851509
single_repr = single_transition(single_repr) + single_repr
14861510

@@ -4493,7 +4517,7 @@ def __init__(
44934517
num_pae_bins = 64,
44944518
pairformer_depth = 4,
44954519
pairformer_kwargs: dict = dict(),
4496-
checkpoint=False
4520+
checkpoint = False
44974521
):
44984522
super().__init__()
44994523

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.6.11"
3+
version = "0.7.0"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)