@@ -607,7 +607,10 @@ def forward(
         self,
         x: Float['... n d'],
         **kwargs
-    ) -> Float['... n d']:
+    ) -> (
+        Float['... n d'] |
+        tuple[Float['... n d'], Any]
+    ):

         x = self.norm(x)
         return self.fn(x, **kwargs)
@@ -678,13 +681,26 @@ def forward(
         *,
         cond: Float['b n dc'],
         **kwargs
-    ) -> Float['b n d']:
+    ) -> (
+        Float['b n d'] |
+        tuple[Float['b n d'], Float['b _ _']]
+    ):
         x = self.adaptive_norm(x, cond = cond)

         out = self.fn(x, **kwargs)

+        tuple_output = isinstance(out, tuple)
+
+        if tuple_output:
+            out, *rest = out
+
         gamma = self.to_adaln_zero_gamma(cond)
-        return out * gamma
+        out = out * gamma
+
+        if tuple_output:
+            out = (out, *rest)
+
+        return out

 # triangle multiplicative module
 # seems to be unchanged from alphafold2
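
The wrapper change above makes the adaLN-Zero gating transparent to tuple outputs: the learned gamma gates only the primary output, while any extra entries (such as attention values carried for value residual learning) pass through ungated. Below is a minimal standalone sketch of that pattern, assuming a sigmoid-gated linear initialized toward a closed gate; the exact parameterization of `to_adaln_zero_gamma` is not shown in this diff.

    import torch
    from torch import nn

    class AdaLNZeroGate(nn.Module):
        # illustrative sketch, not this repo's exact conditioning wrapper internals
        def __init__(self, dim, dim_cond, init_bias = -2.):
            super().__init__()
            linear = nn.Linear(dim_cond, dim)
            nn.init.zeros_(linear.weight)
            nn.init.constant_(linear.bias, init_bias)  # gate starts near closed
            self.to_gamma = nn.Sequential(linear, nn.Sigmoid())

        def forward(self, out, cond):
            tuple_output = isinstance(out, tuple)

            if tuple_output:
                out, *rest = out              # peel off extras, e.g. attention values

            out = out * self.to_gamma(cond)   # gate only the primary output

            if tuple_output:
                out = (out, *rest)            # reattach extras, ungated

            return out
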
@@ -762,7 +778,10 @@ def __init__(self, *, heads, dim_pairwise, window_size=None, num_memory_kv=0, **
         self.window_size = window_size

         self.attn = Attention(
-            heads = heads, window_size = window_size, num_memory_kv = num_memory_kv, **attn_kwargs
+            heads = heads,
+            window_size = window_size,
+            num_memory_kv = num_memory_kv,
+            **attn_kwargs
         )

         # line 8 of Algorithm 24
@@ -777,8 +796,14 @@ def forward(
         *,
         pairwise_repr: Float["b n n dp"] | Float["b nw w (w*2) dp"],  # type: ignore
         attn_bias: Float["b n n"] | Float["b nw w (w*2)"] | None = None,  # type: ignore
+        return_values: bool = False,
+        value_residual: Float['b _ _'] | None = None,
         **kwargs,
-    ) -> Float["b n ds"]:  # type: ignore
+    ) -> (
+        Float['b n ds'] |
+        tuple[Float['b n ds'], Float['b _ _']]
+    ):  # type: ignore
+
         """Perform the forward pass.

         :param single_repr: The single representation tensor.
@@ -837,9 +862,22 @@ def forward(
         else:
             attn_bias = self.to_attn_bias(self.to_attn_bias_norm(pairwise_repr)) + attn_bias

-        out = self.attn(single_repr, attn_bias = attn_bias, **kwargs)
+        # attention

-        return out
+        out, values = self.attn(
+            single_repr,
+            attn_bias = attn_bias,
+            value_residual = value_residual,
+            return_values = True,
+            **kwargs
+        )
+
+        # whether to return values for value residual learning
+
+        if not return_values:
+            return out
+
+        return out, values

 class TriangleAttention(Module):
     def __init__(
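
This is the heart of the commit: value residual learning (from the ResFormer line of work, introduced to alleviate attention concentration in deeper transformer stacks) lets every attention layer blend its value projections with the values computed by the first layer. The new `return_values` / `value_residual` arguments expose that plumbing; the blending itself happens inside `Attention` and is not shown in this diff. A minimal sketch of the idea, with a hypothetical learned per-token mixing gate:

    import torch
    from torch import nn

    class ToyValueResidualAttention(nn.Module):
        # illustrative stand-in for an attention layer that supports value residuals
        def __init__(self, dim, heads = 8):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
            self.to_values = nn.Linear(dim, dim)
            self.to_mix = nn.Linear(dim, 1)  # hypothetical mixing gate

        def forward(self, x, value_residual = None, return_values = False):
            values = self.to_values(x)

            if value_residual is not None:
                # blend this layer's values with the first layer's values
                mix = self.to_mix(x).sigmoid()
                values = values * mix + value_residual * (1. - mix)

            out, _ = self.attn(x, x, values)

            if not return_values:
                return out

            return out, values

Note that only the first layer's values ever become the residual: the `default(...)` latch in the layer loops further down reuses them rather than overwriting them.
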
@@ -1360,6 +1398,7 @@ def __init__(
         dropout_row_prob = 0.25,
         num_register_tokens = 0,
         checkpoint = False,
+        add_value_residual = False,
         pairwise_block_kwargs: dict = dict(),
         pair_bias_attn_kwargs: dict = dict()
     ):
@@ -1395,6 +1434,8 @@ def __init__(

         self.layers = layers

+        self.add_value_residual = add_value_residual
+
         # checkpointing

         self.checkpoint = checkpoint
@@ -1423,6 +1464,8 @@ def to_layers(

     ) -> Tuple[Float['b n ds'], Float['b n n dp']]:

+        value_residual = None
+
         for _ in range(self.recurrent_depth):
             for (
                 pairwise_block,
@@ -1432,7 +1475,13 @@ def to_layers(

                 pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)

-                single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
+                attn_out, attn_values = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = value_residual)
+
+                single_repr = single_repr + attn_out
+
+                if self.add_value_residual:
+                    value_residual = default(value_residual, attn_values)
+
                 single_repr = single_transition(single_repr) + single_repr

         return single_repr, pairwise_repr
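
The `default` helper above follows the usual convention in this codebase, paired with the `exists` seen later in the diff: return the first argument if it is not None, otherwise the fallback. That makes the assignment a one-time latch, capturing the residual from the first layer and leaving it untouched on every later iteration:

    def exists(v):
        return v is not None

    def default(v, d):
        return v if exists(v) else d

    value_residual = None
    for layer_values in ('v1', 'v2', 'v3'):
        value_residual = default(value_residual, layer_values)

    assert value_residual == 'v1'  # latched to the first layer's values
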
@@ -1447,30 +1496,35 @@ def to_checkpointed_layers(

     ) -> Tuple[Float['b n ds'], Float['b n n dp']]:

-        inputs = (single_repr, pairwise_repr, mask)
+        inputs = (single_repr, pairwise_repr, mask, None)

         def pairwise_block_wrapper(layer):
             @wraps(layer)
             def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask = inputs
+                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
                 pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
-                return single_repr, pairwise_repr, mask
+                return single_repr, pairwise_repr, mask, maybe_value_residual
             return inner

         def pair_bias_attn_wrapper(layer):
             @wraps(layer)
             def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask = inputs
-                single_repr = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
-                return single_repr, pairwise_repr, mask
+                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
+                attn_out, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
+                single_repr = single_repr + attn_out
+
+                if self.add_value_residual:
+                    maybe_value_residual = default(maybe_value_residual, attn_values)
+
+                return single_repr, pairwise_repr, mask, maybe_value_residual
             return inner

         def single_transition_wrapper(layer):
             @wraps(layer)
             def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask = inputs
+                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
                 single_repr = layer(single_repr) + single_repr
-                return single_repr, pairwise_repr, mask
+                return single_repr, pairwise_repr, mask, maybe_value_residual
             return inner

         wrapped_layers = []
@@ -1489,7 +1543,7 @@ def inner(inputs, *args, **kwargs):
         for layer in wrapped_layers:
             inputs = checkpoint(layer, inputs)

-        single_repr, pairwise_repr, _ = inputs
+        single_repr, pairwise_repr, *_ = inputs
         return single_repr, pairwise_repr

     @typecheck
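
The carried tuple gains a fourth slot because `torch.utils.checkpoint` recomputes each wrapped layer from its inputs during the backward pass, so any state a layer needs, including the value residual, must travel through `inputs` rather than be closed over. A minimal sketch of the threading pattern, with hypothetical layers that each map the full tuple to a new tuple:

    from torch.utils.checkpoint import checkpoint

    def run_checkpointed(layers, x, extra = None):
        # each wrapped layer takes and returns the whole carried tuple, so
        # recomputation during backward sees exactly the same inputs
        inputs = (x, extra)

        for layer in layers:
            inputs = checkpoint(layer, inputs, use_reentrant = False)

        x, *_ = inputs
        return x
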
@@ -1915,9 +1969,9 @@ def __init__(
         attn_num_memory_kv = False,
         trans_expansion_factor = 2,
         num_register_tokens = 0,
-        add_residual = True,
         use_linear_attn = False,
         checkpoint = False,
+        add_value_residual = False,
         linear_attn_kwargs = dict(
             heads = 8,
             dim_head = 16
@@ -1997,7 +2051,7 @@ def __init__(

         self.layers = layers

-        self.add_residual = add_residual
+        self.add_value_residual = add_value_residual

         self.has_registers = num_register_tokens > 0
         self.num_registers = num_register_tokens
@@ -2018,32 +2072,37 @@ def to_checkpointed_serial_layers(
         windowed_mask: Bool['b nw w (w*2)'] | None = None
     ):

-        inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask)
+        inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)

         wrapped_layers = []

         def efficient_attn_wrapper(fn):
             @wraps(fn)
             def inner(inputs):
-                noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
+                noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
                 noised_repr = fn(noised_repr, mask = mask) + noised_repr
-                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
+                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
             return inner

         def attn_wrapper(fn):
             @wraps(fn)
             def inner(inputs):
-                noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
-                noised_repr = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask) + noised_repr
-                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
+                noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
+                attn_out, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)
+                noised_repr = attn_out + noised_repr
+
+                if self.add_value_residual:
+                    maybe_value_residual = default(maybe_value_residual, attn_values)
+
+                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
             return inner

         def transition_wrapper(fn):
             @wraps(fn)
             def inner(inputs):
-                noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
+                noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
                 noised_repr = fn(noised_repr, cond = single_repr) + noised_repr
-                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
+                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
             return inner

         for linear_attn, colt5_attn, attn, transition in self.layers:
@@ -2074,6 +2133,8 @@ def to_serial_layers(
         windowed_mask: Bool['b nw w (w*2)'] | None = None
     ):

+        value_residual = None
+
         for linear_attn, colt5_attn, attn, transition in self.layers:

             if exists(linear_attn):
@@ -2082,13 +2143,20 @@
             if exists(colt5_attn):
                 noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr

-            noised_repr = attn(
+            attn_out, attn_values = attn(
                 noised_repr,
                 cond = single_repr,
                 pairwise_repr = pairwise_repr,
                 mask = mask,
-                windowed_mask = windowed_mask
-            ) + noised_repr
+                windowed_mask = windowed_mask,
+                return_values = True,
+                value_residual = value_residual
+            )
+
+            noised_repr = noised_repr + attn_out
+
+            if self.add_value_residual:
+                value_residual = default(value_residual, attn_values)

             noised_repr = transition(
                 noised_repr,
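
Taken together, the feature is opt-in through the new constructor flag. A hypothetical usage sketch, assuming the enclosing class here is this repository's `DiffusionTransformer` and that `dim`, `depth`, and `heads` keep their existing meanings; only `add_value_residual` is new in this commit:

    diffusion_transformer = DiffusionTransformer(
        dim = 384,
        depth = 12,
        heads = 16,
        add_value_residual = True  # later layers reuse the first layer's attention values
    )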