
Commit f66be08

add layerscale, for zeroing out the contributions from the MSA module and template embedder, so one can introduce MSAs and templates at a later stage
1 parent dc40483 commit f66be08

3 files changed (+25, -6)
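For context, here is a minimal sketch of the pattern the commit message describes: the module's output is multiplied by a learnable per-channel scale initialized to zero, so a freshly added MSA module or template embedder contributes nothing to the pairwise representation until training moves the scale away from zero. The names below (`ToyGatedModule`, `net`, the tensor shapes) are illustrative stand-ins, not the repository's API; only the `layerscale_output` construction mirrors the diff that follows.

```python
import torch
from torch import nn

class ToyGatedModule(nn.Module):
    def __init__(self, dim_pairwise = 128, layerscale_output = True):
        super().__init__()
        self.net = nn.Linear(dim_pairwise, dim_pairwise)

        # learnable per-channel scale, zero-initialized (mirrors the diff);
        # falls back to a plain 1. multiplier when the gate is disabled
        self.layerscale_output = nn.Parameter(torch.zeros(dim_pairwise)) if layerscale_output else 1.

    def forward(self, pairwise_repr):
        out = self.net(pairwise_repr)
        return out * self.layerscale_output

module = ToyGatedModule()
pairwise_repr = torch.randn(2, 16, 16, 128)

# at initialization the gated module contributes exactly zero to the
# pairwise representation, so MSA / template information can be switched
# on later without perturbing what has already been learned
assert torch.allclose(module(pairwise_repr), torch.zeros_like(pairwise_repr))
```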

README.md

Lines changed: 11 additions & 0 deletions
````diff
@@ -248,3 +248,14 @@ docker run -v .:/data --gpus all -it af3
     url = {https://api.semanticscholar.org/CorpusID:258564608}
 }
 ```
+
+```bibtex
+@article{Wang2022DeepNetST,
+    title = {DeepNet: Scaling Transformers to 1, 000 Layers},
+    author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
+    journal = {ArXiv},
+    year = {2022},
+    volume = {abs/2203.00555},
+    url = {https://api.semanticscholar.org/CorpusID:247187905}
+}
+```
````

alphafold3_pytorch/alphafold3.py

Lines changed: 13 additions & 5 deletions
```diff
@@ -900,7 +900,8 @@ def __init__(
         msa_pwa_heads = 8,
         msa_pwa_dim_head = 32,
         pairwise_block_kwargs: dict = dict(),
-        max_num_msa: int | None = None
+        max_num_msa: int | None = None,
+        layerscale_output: bool = True
     ):
         super().__init__()
 
@@ -947,6 +948,8 @@ def __init__(
 
         self.layers = layers
 
+        self.layerscale_output = nn.Parameter(torch.zeros(dim_pairwise)) if layerscale_output else 1.
+
     @typecheck
     def forward(
         self,
@@ -1012,7 +1015,7 @@ def forward(
             has_msa, pairwise_repr, 0.
         )
 
-        return pairwise_repr
+        return pairwise_repr * self.layerscale_output
 
 # pairformer stack
 
@@ -1214,7 +1217,8 @@ def __init__(
         dim_pairwise = 128,
         pairformer_stack_depth = 2,
         pairwise_block_kwargs: dict = dict(),
-        eps = 1e-5
+        eps = 1e-5,
+        layerscale_output = True
     ):
         super().__init__()
         self.eps = eps
@@ -1246,6 +1250,8 @@ def __init__(
             nn.ReLU()
         )
 
+        self.layerscale = nn.Parameter(torch.zeros(dim_pairwise)) if layerscale_output else 1.
+
     @typecheck
     def forward(
         self,
@@ -1299,7 +1305,7 @@ def forward(
             has_templates, out, 0.
         )
 
-        return out
+        return out * self.layerscale
 
 # diffusion related
 # both diffusion transformer as well as atom encoder / decoder
@@ -2852,6 +2858,7 @@ def __init__(
         template_embedder_kwargs: dict = dict(
             pairformer_stack_depth = 2,
             pairwise_block_kwargs = dict(),
+            layerscale_output = True,
         ),
         msa_module_kwargs: dict = dict(
             depth = 4,
@@ -2861,7 +2868,8 @@ def __init__(
             msa_pwa_dropout_row_prob = 0.15,
             msa_pwa_heads = 8,
             msa_pwa_dim_head = 32,
-            pairwise_block_kwargs = dict()
+            pairwise_block_kwargs = dict(),
+            layerscale_output = True,
         ),
         pairformer_stack: dict = dict(
             depth = 48,
```
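As a usage note, the dicts below restate the constructor defaults visible in the hunks above, with the new `layerscale_output` flag included. This is a hedged reconstruction: keys that fall outside the shown hunks are omitted, so these are not the complete defaults of the class being configured. Passing `layerscale_output = False` keeps the plain `1.` multiplier, i.e. the previous ungated behavior.

```python
# reconstructed only from the hunks above; keys outside the shown hunks are omitted

template_embedder_kwargs = dict(
    pairformer_stack_depth = 2,
    pairwise_block_kwargs = dict(),
    layerscale_output = True,   # new: zero-init gate on the template embedder output
)

msa_module_kwargs = dict(
    depth = 4,
    msa_pwa_dropout_row_prob = 0.15,
    msa_pwa_heads = 8,
    msa_pwa_dim_head = 32,
    pairwise_block_kwargs = dict(),
    layerscale_output = True,   # new: zero-init gate on the MSA module output
)
```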

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.1.27"
+version = "0.1.28"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
```
