@@ -683,7 +683,7 @@ def forward(
         **kwargs
     ) -> (
         Float['b n d'] |
-        tuple[Float['b n d'], Float['b _ _']]
+        tuple[Float['b n d'], Float['b _ _ _']]
     ):
         x = self.adaptive_norm(x, cond = cond)

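The return annotation widens from `Float['b _ _']` to `Float['b _ _ _']` because the attention values handed back for residual reuse now keep a separate heads axis, e.g. `(batch, heads, seq, dim_head)`, rather than a heads-merged `(batch, seq, dim)`. Illustrative shapes only (names below are made up for the sketch):

```python
import torch

b, h, n, d = 2, 8, 16, 64

values_merged   = torch.randn(b, n, h * d)  # old 'b _ _' annotation: heads folded into features
values_per_head = torch.randn(b, h, n, d)   # new 'b _ _ _' annotation: heads kept as their own axis
```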
@@ -797,11 +797,11 @@ def forward(
         pairwise_repr: Float["b n n dp"] | Float["b nw w (w*2) dp"], # type: ignore
         attn_bias: Float["b n n"] | Float["b nw w (w*2)"] | None = None, # type: ignore
         return_values: bool = False,
-        value_residual: Float['b _ _'] | None = None,
+        value_residual: Float['b _ _ _'] | None = None,
         **kwargs,
     ) -> (
         Float['b n ds'] |
-        tuple[Float['b n ds'], Float['b _ _']]
+        tuple[Float['b n ds'], Float['b _ _ _']]
     ): # type: ignore

         """Perform the forward pass.
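For orientation, a hypothetical caller of this signature (the two layer instances are illustrative, not from the diff): the first layer is asked to hand back its values, and a later layer, constructed with `accept_value_residual = True`, mixes them into its own:

```python
# hypothetical caller sketch - `first_attn` / `later_attn` are assumed instances
single_out, cached_values = first_attn(
    single_repr,
    pairwise_repr = pairwise_repr,
    return_values = True
)

later_out = later_attn(
    single_repr,
    pairwise_repr = pairwise_repr,
    value_residual = cached_values  # only valid when built with accept_value_residual = True
)
```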
@@ -961,6 +961,7 @@ def __init__(
         tri_attn_heads = 4,
         dropout_row_prob = 0.25,
         dropout_col_prob = 0.25,
+        accept_value_residual = False
     ):
         super().__init__()

@@ -974,7 +975,8 @@ def __init__(
         tri_attn_kwargs = dict(
             dim = dim_pairwise,
             heads = tri_attn_heads,
-            dim_head = tri_attn_dim_head
+            dim_head = tri_attn_dim_head,
+            accept_value_residual = accept_value_residual
         )

         self.tri_mult_outgoing = pre_ln(TriangleMultiplication(mix = 'outgoing', dropout = dropout_row_prob, dropout_type = 'row', **tri_mult_kwargs))
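`PairwiseBlock` merely forwards the new flag into its triangle-attention kwargs. A minimal sketch of what accepting a value residual typically looks like inside an attention layer, assuming a learned per-head sigmoid gate (class and names here are illustrative, not the repo's code):

```python
import torch
from torch import nn

class ValueResidualMix(nn.Module):
    """Blend freshly projected values with values cached from the first layer."""

    def __init__(self, dim, heads):
        super().__init__()
        self.to_mix = nn.Linear(dim, heads)

    def forward(self, x, v, value_residual):
        # x: (b, n, dim) tokens; v / value_residual: (b, h, n, d) per-head values
        mix = self.to_mix(x).sigmoid()           # (b, n, h) gate in [0, 1]
        mix = mix.transpose(1, 2).unsqueeze(-1)  # (b, h, n, 1) for broadcasting
        return v * mix + value_residual * (1. - mix)
```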
@@ -1436,16 +1438,20 @@ def __init__(
             **pair_bias_attn_kwargs
         )

-        for _ in range(depth):
+        for i in range(depth):
+
+            is_first = i == 0
+            accept_value_residual = add_value_residual and not is_first

             single_pre_ln = partial(PreLayerNorm, dim = dim_single)

             pairwise_block = PairwiseBlock(
                 dim_pairwise = dim_pairwise,
+                accept_value_residual = accept_value_residual,
                 **pairwise_block_kwargs
             )

-            pair_bias_attn = AttentionPairBias(**pair_bias_attn_kwargs)
+            pair_bias_attn = AttentionPairBias(accept_value_residual = accept_value_residual, **pair_bias_attn_kwargs)
             single_transition = Transition(dim = dim_single)

             layers.append(ModuleList([
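The loop now distinguishes the producer from the consumers: layer 0 always runs without a residual (it supplies the values), and every later layer accepts one only when `add_value_residual` is set. Spelled out for `depth = 4`:

```python
depth = 4
add_value_residual = True

flags = [add_value_residual and i != 0 for i in range(depth)]
print(flags)  # [False, True, True, True] - layer 0 caches values, layers 1..3 mix them back in
```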
@@ -1486,10 +1492,11 @@ def to_layers(

     ) -> Tuple[Float['b n ds'], Float['b n n dp']]:

-        value_residual = None
-        pairwise_value_residuals = None
-
         for _ in range(self.recurrent_depth):
+
+            value_residual = None
+            pairwise_value_residuals = None
+
             for (
                 pairwise_block,
                 pair_bias_attn,
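Moving the two `None` resets inside the recurrent loop fixes a leak: with `recurrent_depth > 1`, values cached during pass 0 would have been consumed by pass 1, whereas each pass should start from its own first-layer values. The plural `pairwise_value_residuals` exists because a pairwise block holds more than one attention layer, so its cache is a collection rather than a single tensor. Skeleton of the corrected control flow (stand-in values, runnable as-is):

```python
recurrent_depth = 2
layers = range(3)  # stand-ins for (pairwise_block, pair_bias_attn, single_transition) triples

for _ in range(recurrent_depth):
    value_residual = None            # single-track attention values
    pairwise_value_residuals = None  # one entry per attention layer in the pairwise block
    for layer in layers:
        ...                          # each pass re-caches from its own first layer
```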
@@ -1520,54 +1527,59 @@ def to_checkpointed_layers(

     ) -> Tuple[Float['b n ds'], Float['b n n dp']]:

-        inputs = (single_repr, pairwise_repr, mask, None)
-
         def pairwise_block_wrapper(layer):
             @wraps(layer)
             def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
-                pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
-                return single_repr, pairwise_repr, mask, maybe_value_residual
+                single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
+                pairwise_repr, pairwise_attn_values = layer(pairwise_repr = pairwise_repr, mask = mask, value_residuals = maybe_pairwise_value_residuals, return_values = True)
+
+                if self.add_value_residual:
+                    maybe_pairwise_value_residuals = default(maybe_pairwise_value_residuals, pairwise_attn_values)
+
+                return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
             return inner

         def pair_bias_attn_wrapper(layer):
             @wraps(layer)
             def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
+                single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
                 attn_out, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
                 single_repr = single_repr + attn_out

                 if self.add_value_residual:
                     maybe_value_residual = default(maybe_value_residual, attn_values)

-                return single_repr, pairwise_repr, mask, maybe_value_residual
+                return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
             return inner

         def single_transition_wrapper(layer):
             @wraps(layer)
             def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
+                single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals = inputs
                 single_repr = layer(single_repr) + single_repr
-                return single_repr, pairwise_repr, mask, maybe_value_residual
+                return single_repr, pairwise_repr, mask, maybe_value_residual, maybe_pairwise_value_residuals
             return inner

         wrapped_layers = []

+        for (
+            pairwise_block,
+            pair_bias_attn,
+            single_transition
+        ) in self.layers:
+
+            wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
+            wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
+            wrapped_layers.append(single_transition_wrapper(single_transition))
+
         for _ in range(self.recurrent_depth):
-            for (
-                pairwise_block,
-                pair_bias_attn,
-                single_transition
-            ) in self.layers:
+            inputs = (single_repr, pairwise_repr, mask, None, None)

-                wrapped_layers.append(pairwise_block_wrapper(pairwise_block))
-                wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
-                wrapped_layers.append(single_transition_wrapper(single_transition))
+            for layer in wrapped_layers:
+                inputs = checkpoint(layer, inputs)

-        for layer in wrapped_layers:
-            inputs = checkpoint(layer, inputs)
+            single_repr, pairwise_repr, *_ = inputs

-        single_repr, pairwise_repr, *_ = inputs
         return single_repr, pairwise_repr

     @typecheck
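`torch.utils.checkpoint.checkpoint` only re-invokes a callable on the inputs it was given, so every piece of cross-layer state, including both value caches, has to ride inside the single `inputs` tuple that each wrapper unpacks and returns; the hunk also hoists the wrapping out of the recurrence, since the same wrapped layers are reused every pass and each pass starts from a fresh `(..., None, None)` state. A minimal, self-contained sketch of the threading pattern (not the repo's code):

```python
import torch
from torch.utils.checkpoint import checkpoint

def layer(inputs):
    # unpack the whole state, update it, and pass it all back out
    x, maybe_residual = inputs
    y = x * 2.
    maybe_residual = y if maybe_residual is None else maybe_residual  # cache once: first layer wins
    return y, maybe_residual

x = torch.randn(2, 8, requires_grad = True)
inputs = (x, None)

for _ in range(3):  # three "layers", each checkpointed
    inputs = checkpoint(layer, inputs, use_reentrant = False)

out, cached = inputs
out.sum().backward()  # activations are recomputed on backward, state still threads through
```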
@@ -2016,7 +2028,8 @@ def __init__(

         layers = ModuleList([])

-        for _ in range(depth):
+        for i in range(depth):
+            is_first = i == 0

             linear_attn = None

@@ -2038,12 +2051,15 @@ def __init__(
                 **colt5_attn_kwargs
             )

+            accept_value_residual = add_value_residual and not is_first
+
             pair_bias_attn = AttentionPairBias(
                 dim = dim,
                 dim_pairwise = dim_pairwise,
                 heads = heads,
                 window_size = attn_window_size,
                 num_memory_kv = attn_num_memory_kv,
+                accept_value_residual = accept_value_residual,
                 **attn_pair_bias_kwargs
             )

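Same producer/consumer split as in the trunk above, now applied to this module's depth loop. The pair of constructions it would produce for `depth = 2`, sketched with placeholder dims and with the window/memory-kv arguments omitted for brevity:

```python
# illustrative pair of constructions for depth = 2; dims and heads are placeholders
first = AttentionPairBias(dim = 384, dim_pairwise = 128, heads = 8)                                # produces values
later = AttentionPairBias(dim = 384, dim_pairwise = 128, heads = 8, accept_value_residual = True)  # consumes them
```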