@@ -907,8 +907,12 @@ def forward(
907907 self ,
908908 pairwise_repr : Float ['b n n d' ],
909909 mask : Bool ['b n' ] | None = None ,
910+ return_values = False ,
910911 ** kwargs
911- ) -> Float ['b n n d' ]:
912+ ) -> (
913+ Float ['b n n d' ] |
914+ tuple [Float ['b n n d' ], Tensor ]
915+ ):
912916
913917 if self .need_transpose :
914918 pairwise_repr = rearrange (pairwise_repr , 'b i j d -> b j i d' )
@@ -923,10 +927,11 @@ def forward(
923927
924928 pairwise_repr , unpack_one = pack_one (pairwise_repr , '* n d' )
925929
926- out = self .attn (
930+ out , values = self .attn (
927931 pairwise_repr ,
928932 mask = mask ,
929933 attn_bias = attn_bias ,
934+ return_values = True ,
930935 ** kwargs
931936 )
932937
@@ -935,7 +940,12 @@ def forward(
935940 if self .need_transpose :
936941 out = rearrange (out , 'b j i d -> b i j d' )
937942
938- return self .dropout (out )
943+ out = self .dropout (out )
944+
945+ if not return_values :
946+ return out
947+
948+ return out , values
939949
940950# PairwiseBlock
941951# used in both MSAModule and Pairformer
@@ -978,15 +988,27 @@ def forward(
978988 self ,
979989 * ,
980990 pairwise_repr : Float ['b n n d' ],
981- mask : Bool ['b n' ] | None = None
991+ mask : Bool ['b n' ] | None = None ,
992+ value_residuals : tuple [Tensor , Tensor ] | None = None ,
993+ return_values = False ,
982994 ):
983995 pairwise_repr = self .tri_mult_outgoing (pairwise_repr , mask = mask ) + pairwise_repr
984996 pairwise_repr = self .tri_mult_incoming (pairwise_repr , mask = mask ) + pairwise_repr
985- pairwise_repr = self .tri_attn_starting (pairwise_repr , mask = mask ) + pairwise_repr
986- pairwise_repr = self .tri_attn_ending (pairwise_repr , mask = mask ) + pairwise_repr
997+
998+ attn_start_value_residual , attn_end_value_residual = default (value_residuals , (None , None ))
999+
1000+ attn_start_out , attn_start_values = self .tri_attn_starting (pairwise_repr , mask = mask , value_residual = attn_start_value_residual , return_values = True )
1001+ pairwise_repr = attn_start_out + pairwise_repr
1002+
1003+ attn_end_out , attn_end_values = self .tri_attn_ending (pairwise_repr , mask = mask , value_residual = attn_end_value_residual , return_values = True )
1004+ pairwise_repr = attn_end_out + pairwise_repr
9871005
9881006 pairwise_repr = self .pairwise_transition (pairwise_repr ) + pairwise_repr
989- return pairwise_repr
1007+
1008+ if not return_values :
1009+ return pairwise_repr
1010+
1011+ return pairwise_repr , (attn_start_values , attn_end_values )
9901012
9911013# msa module
9921014
@@ -1465,6 +1487,7 @@ def to_layers(
14651487 ) -> Tuple [Float ['b n ds' ], Float ['b n n dp' ]]:
14661488
14671489 value_residual = None
1490+ pairwise_value_residuals = None
14681491
14691492 for _ in range (self .recurrent_depth ):
14701493 for (
@@ -1473,14 +1496,15 @@ def to_layers(
14731496 single_transition
14741497 ) in self .layers :
14751498
1476- pairwise_repr = pairwise_block (pairwise_repr = pairwise_repr , mask = mask )
1499+ pairwise_repr , pairwise_attn_values = pairwise_block (pairwise_repr = pairwise_repr , mask = mask , value_residuals = pairwise_value_residuals , return_values = True )
14771500
14781501 attn_out , attn_values = pair_bias_attn (single_repr , pairwise_repr = pairwise_repr , mask = mask , return_values = True , value_residual = value_residual )
14791502
14801503 single_repr = single_repr + attn_out
14811504
14821505 if self .add_value_residual :
14831506 value_residual = default (value_residual , attn_values )
1507+ pairwise_value_residuals = default (pairwise_value_residuals , pairwise_attn_values )
14841508
14851509 single_repr = single_transition (single_repr ) + single_repr
14861510
@@ -4493,7 +4517,7 @@ def __init__(
44934517 num_pae_bins = 64 ,
44944518 pairformer_depth = 4 ,
44954519 pairformer_kwargs : dict = dict (),
4496- checkpoint = False
4520+ checkpoint = False
44974521 ):
44984522 super ().__init__ ()
44994523
0 commit comments