Commit 95d3ab6

value residual learning (#312)
* cite
* add value residual learning
* oops
* slip in value residual learning for pairformer stack
* also cite Nguyen, whose initial paper led here
1 parent 9e5bb92 commit 95d3ab6

5 files changed (+152 -37 lines)
README.md

Lines changed: 20 additions & 0 deletions
@@ -494,3 +494,23 @@ docker run -v .:/data --gpus all -it af3
     url = {https://api.semanticscholar.org/CorpusID:267657558}
 }
 ```
+
+```bibtex
+@article{Nguyen2023MitigatingOI,
+    title = {Mitigating Over-smoothing in Transformers via Regularized Nonlocal Functionals},
+    author = {Tam Nguyen and Tan M. Nguyen and Richard G. Baraniuk},
+    journal = {ArXiv},
+    year = {2023},
+    volume = {abs/2312.00751},
+    url = {https://api.semanticscholar.org/CorpusID:264300597}
+}
+```
+
+```bibtex
+@inproceedings{Zhou2024ValueRL,
+    title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
+    author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273532030}
+}
+```

alphafold3_pytorch/alphafold3.py

Lines changed: 98 additions & 30 deletions
@@ -607,7 +607,10 @@ def forward(
         self,
         x: Float['... n d'],
         **kwargs
-    ) -> Float['... n d']:
+    ) -> (
+        Float['... n d'] |
+        tuple[Float['... n d'] | Any]
+    ):

         x = self.norm(x)
         return self.fn(x, **kwargs)
@@ -678,13 +681,26 @@ def forward(
         *,
         cond: Float['b n dc'],
         **kwargs
-    ) -> Float['b n d']:
+    ) -> (
+        Float['b n d'] |
+        tuple[Float['b n d'], Float['b _ _']]
+    ):
         x = self.adaptive_norm(x, cond = cond)

         out = self.fn(x, **kwargs)

+        tuple_output = isinstance(out, tuple)
+
+        if tuple_output:
+            out, *rest = out
+
         gamma = self.to_adaln_zero_gamma(cond)
-        return out * gamma
+        out = out * gamma
+
+        if tuple_output:
+            out = (out, *rest)
+
+        return out

 # triangle multiplicative module
 # seems to be unchanged from alphafold2
@@ -762,7 +778,10 @@ def __init__(self, *, heads, dim_pairwise, window_size=None, num_memory_kv=0, **
         self.window_size = window_size

         self.attn = Attention(
-            heads=heads, window_size=window_size, num_memory_kv=num_memory_kv, **attn_kwargs
+            heads = heads,
+            window_size = window_size,
+            num_memory_kv = num_memory_kv,
+            **attn_kwargs
         )

         # line 8 of Algorithm 24
@@ -777,8 +796,14 @@ def forward(
         *,
         pairwise_repr: Float["b n n dp"] | Float["b nw w (w*2) dp"], # type: ignore
         attn_bias: Float["b n n"] | Float["b nw w (w*2)"] | None = None, # type: ignore
+        return_values: bool = False,
+        value_residual: Float['b _ _'] | None = None,
         **kwargs,
-    ) -> Float["b n ds"]: # type: ignore
+    ) -> (
+        Float['b n ds'] |
+        tuple[Float['b n ds'], Float['b _ _']]
+    ): # type: ignore
+
         """Perform the forward pass.

         :param single_repr: The single representation tensor.
@@ -837,9 +862,22 @@ def forward(
         else:
             attn_bias = self.to_attn_bias(self.to_attn_bias_norm(pairwise_repr)) + attn_bias

-        out = self.attn(single_repr, attn_bias=attn_bias, **kwargs)
+        # attention

-        return out
+        out, values = self.attn(
+            single_repr,
+            attn_bias = attn_bias,
+            value_residual = value_residual,
+            return_values = True,
+            **kwargs
+        )
+
+        # whether to return values for value residual learning
+
+        if not return_values:
+            return out
+
+        return out, values

 class TriangleAttention(Module):
     def __init__(
@@ -1360,6 +1398,7 @@ def __init__(
         dropout_row_prob = 0.25,
         num_register_tokens = 0,
         checkpoint = False,
+        add_value_residual = False,
         pairwise_block_kwargs: dict = dict(),
         pair_bias_attn_kwargs: dict = dict()
     ):
@@ -1395,6 +1434,8 @@ def __init__(

         self.layers = layers

+        self.add_value_residual = add_value_residual
+
         # checkpointing

         self.checkpoint = checkpoint
@@ -1423,6 +1464,8 @@ def to_layers(

     ) -> Tuple[Float['b n ds'], Float['b n n dp']]:

+        value_residual = None
+
         for _ in range(self.recurrent_depth):
             for (
                 pairwise_block,
@@ -1432,7 +1475,13 @@ def to_layers(

                 pairwise_repr = pairwise_block(pairwise_repr = pairwise_repr, mask = mask)

-                single_repr = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
+                attn_out, attn_values = pair_bias_attn(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = value_residual)
+
+                single_repr = single_repr + attn_out
+
+                if self.add_value_residual:
+                    value_residual = default(value_residual, attn_values)
+
                 single_repr = single_transition(single_repr) + single_repr

         return single_repr, pairwise_repr
@@ -1447,30 +1496,35 @@ def to_checkpointed_layers(

     ) -> Tuple[Float['b n ds'], Float['b n n dp']]:

-        inputs = (single_repr, pairwise_repr, mask)
+        inputs = (single_repr, pairwise_repr, mask, None)

         def pairwise_block_wrapper(layer):
             @wraps(layer)
             def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask = inputs
+                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
                 pairwise_repr = layer(pairwise_repr = pairwise_repr, mask = mask)
-                return single_repr, pairwise_repr, mask
+                return single_repr, pairwise_repr, mask, maybe_value_residual
             return inner

         def pair_bias_attn_wrapper(layer):
             @wraps(layer)
             def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask = inputs
-                single_repr = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask) + single_repr
-                return single_repr, pairwise_repr, mask
+                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
+                attn_out, attn_values = layer(single_repr, pairwise_repr = pairwise_repr, mask = mask, return_values = True, value_residual = maybe_value_residual)
+                single_repr = single_repr + attn_out
+
+                if self.add_value_residual:
+                    maybe_value_residual = default(maybe_value_residual, attn_values)
+
+                return single_repr, pairwise_repr, mask, maybe_value_residual
             return inner

         def single_transition_wrapper(layer):
             @wraps(layer)
             def inner(inputs, *args, **kwargs):
-                single_repr, pairwise_repr, mask = inputs
+                single_repr, pairwise_repr, mask, maybe_value_residual = inputs
                 single_repr = layer(single_repr) + single_repr
-                return single_repr, pairwise_repr, mask
+                return single_repr, pairwise_repr, mask, maybe_value_residual
             return inner

         wrapped_layers = []
@@ -1489,7 +1543,7 @@ def inner(inputs, *args, **kwargs):
         for layer in wrapped_layers:
             inputs = checkpoint(layer, inputs)

-        single_repr, pairwise_repr, _ = inputs
+        single_repr, pairwise_repr, *_ = inputs
         return single_repr, pairwise_repr

     @typecheck
@@ -1915,9 +1969,9 @@ def __init__(
         attn_num_memory_kv = False,
         trans_expansion_factor = 2,
         num_register_tokens = 0,
-        add_residual = True,
         use_linear_attn = False,
         checkpoint = False,
+        add_value_residual = False,
         linear_attn_kwargs = dict(
             heads = 8,
             dim_head = 16
@@ -1997,7 +2051,7 @@ def __init__(

         self.layers = layers

-        self.add_residual = add_residual
+        self.add_value_residual = add_value_residual

         self.has_registers = num_register_tokens > 0
         self.num_registers = num_register_tokens
@@ -2018,32 +2072,37 @@ def to_checkpointed_serial_layers(
         windowed_mask: Bool['b nw w (w*2)'] | None = None
     ):

-        inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask)
+        inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask, None)

         wrapped_layers = []

         def efficient_attn_wrapper(fn):
             @wraps(fn)
             def inner(inputs):
-                noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
+                noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
                 noised_repr = fn(noised_repr, mask = mask) + noised_repr
-                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
+                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
             return inner

         def attn_wrapper(fn):
             @wraps(fn)
             def inner(inputs):
-                noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
-                noised_repr = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask) + noised_repr
-                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
+                noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
+                attn_out, attn_values = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask, value_residual = maybe_value_residual, return_values = True)
+                noised_repr = attn_out + noised_repr
+
+                if self.add_value_residual:
+                    maybe_value_residual = default(maybe_value_residual, attn_values)
+
+                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
             return inner

         def transition_wrapper(fn):
             @wraps(fn)
             def inner(inputs):
-                noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
+                noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual = inputs
                 noised_repr = fn(noised_repr, cond = single_repr) + noised_repr
-                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
+                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask, maybe_value_residual
             return inner

         for linear_attn, colt5_attn, attn, transition in self.layers:
@@ -2074,6 +2133,8 @@ def to_serial_layers(
         windowed_mask: Bool['b nw w (w*2)'] | None = None
     ):

+        value_residual = None
+
        for linear_attn, colt5_attn, attn, transition in self.layers:

             if exists(linear_attn):
@@ -2082,13 +2143,20 @@ def to_serial_layers(
             if exists(colt5_attn):
                 noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr

-            noised_repr = attn(
+            attn_out, attn_values = attn(
                 noised_repr,
                 cond = single_repr,
                 pairwise_repr = pairwise_repr,
                 mask = mask,
-                windowed_mask = windowed_mask
-            ) + noised_repr
+                windowed_mask = windowed_mask,
+                return_values = True,
+                value_residual = value_residual
+            )
+
+            noised_repr = noised_repr + attn_out
+
+            if self.add_value_residual:
+                value_residual = default(value_residual, attn_values)

             noised_repr = transition(
                 noised_repr,
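
Both `PairformerStack` and the diffusion transformer thread the residual the same way in the hunks above: `value_residual` starts as `None`, the first attention block's returned values seed it through `default(value_residual, attn_values)` (which keeps the first non-`None` value), and every later block receives it whenever `add_value_residual` is enabled. The checkpointed paths do the same thing, except the residual rides along inside the `inputs` tuple so each wrapped layer stays a single-argument callable for `checkpoint`. Below is a condensed, illustrative sketch of that control flow with a toy block standing in for the real attention modules; only the calling convention and the loop mirror the diff.

```python
# Sketch of the stack-level threading, assuming made-up ToyBlock / ToyStack names.

import torch
from torch import nn

def default(value, fallback):
    # keep the existing value if present, otherwise use the fallback
    return value if value is not None else fallback

class ToyBlock(nn.Module):
    """Stand-in with the same calling convention as the modified attention blocks."""
    def __init__(self, dim):
        super().__init__()
        self.to_v = nn.Linear(dim, dim, bias = False)
        self.proj = nn.Linear(dim, dim, bias = False)

    def forward(self, x, value_residual = None, return_values = False):
        v = self.to_v(x)
        orig_v = v  # pre-mix values, handed to later blocks

        if value_residual is not None:
            v = 0.5 * (v + value_residual)

        out = self.proj(v)
        return (out, orig_v) if return_values else out

class ToyStack(nn.Module):
    def __init__(self, dim, depth, add_value_residual = True):
        super().__init__()
        self.add_value_residual = add_value_residual
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])

    def forward(self, x):
        value_residual = None

        for block in self.blocks:
            out, values = block(x, value_residual = value_residual, return_values = True)
            x = x + out

            # only the first block's values are ever stored
            if self.add_value_residual:
                value_residual = default(value_residual, values)

        return x

x = torch.randn(2, 16, 64)
print(ToyStack(dim = 64, depth = 4)(x).shape)  # torch.Size([2, 16, 64])
```
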

alphafold3_pytorch/attention.py

Lines changed: 24 additions & 3 deletions
@@ -237,15 +237,29 @@ def forward(
         mask: Bool['b n']| None = None,
         context: Float['b j d'] | None = None,
         windowed_mask: Bool['b nw w (w*2)'] | None = None,
-        attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None
+        attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
+        return_values: bool = False,
+        value_residual: Float['b j dh'] | None = None,

-    ) -> Float['b i d']:
+    ) -> (
+        Float['b i d'] |
+        tuple[Float['b i d'], Float['b j _']]
+    ):

         q = self.to_q(seq)

         context_seq = default(context, seq)
         k, v = self.to_kv(context_seq).chunk(2, dim = -1)

+        # handle value residual
+
+        orig_v = v
+
+        if exists(value_residual):
+            v = 0.5 * (v + value_residual)
+
+        # split heads
+
         q, k, v = tuple(self.split_heads(t) for t in (q, k, v))

         # attention
@@ -270,7 +284,14 @@ def forward(

         # combine heads

-        return self.to_out(out)
+        out = self.to_out(out)
+
+        # maybe return values
+
+        if not return_values:
+            return out
+
+        return out, orig_v

 # the main attention function
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.6.5"
+version = "0.6.6"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" },

0 commit comments
