Commit d25fbca

add ability to use conditionally routed attention from CoLT5 paper for alleviating sequence-local issue with atoms

1 parent 32a0839
4 files changed: 38 additions, 4 deletions
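The change is opt-in: `DiffusionTransformer` gains a `use_colt5_attn` flag (defaulting to `False`) which, when set, inserts a `ConditionalRoutedAttention` residual branch after the linear-attention branch in every layer. A minimal construction sketch, with arguments mirroring the new test case in `tests/test_af3.py` below, and assuming the class is importable from the module touched in this commit:

```python
# sketch only: constructor arguments taken from the test diff below
from alphafold3_pytorch.alphafold3 import DiffusionTransformer

diffusion_transformer = DiffusionTransformer(
    depth = 2,
    heads = 16,
    use_colt5_attn = True  # new flag added by this commit; defaults to False
)
```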

README.md
Lines changed: 8 additions & 0 deletions

@@ -263,3 +263,11 @@ docker run -v .:/data --gpus all -it af3
     url = {https://api.semanticscholar.org/CorpusID:247187905}
 }
 ```
+
+```bibtex
+@inproceedings{Ainslie2023CoLT5FL,
+    title   = {CoLT5: Faster Long-Range Transformers with Conditional Computation},
+    author  = {Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Ontan'on and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai},
+    year    = {2023}
+}
+```

alphafold3_pytorch/alphafold3.py
Lines changed: 24 additions & 1 deletion

@@ -41,6 +41,8 @@
 
 from taylor_series_linear_attention import TaylorSeriesLinearAttn
 
+from colt5_attention import ConditionalRoutedAttention
+
 import einx
 from einops import rearrange, repeat, reduce, einsum, pack, unpack
 from einops.layers.torch import Rearrange

@@ -1460,7 +1462,15 @@ def __init__(
         linear_attn_kwargs = dict(
             heads = 8,
             dim_head = 16
+        ),
+        use_colt5_attn = False,
+        colt5_attn_kwargs = dict(
+            heavy_dim_head = 64,
+            heavy_heads = 8,
+            num_heavy_tokens_q = 512,
+            num_heavy_tokens_kv = 512
         )
+
     ):
         super().__init__()
         self.attn_window_size = attn_window_size

@@ -1481,6 +1491,15 @@ def __init__(
             **linear_attn_kwargs
         )
 
+        colt5_attn = None
+
+        if use_colt5_attn:
+            colt5_attn = ConditionalRoutedAttention(
+                dim = dim,
+                has_light_attn = False,
+                **colt5_attn_kwargs
+            )
+
         pair_bias_attn = AttentionPairBias(
             dim = dim,
             dim_pairwise = dim_pairwise,

@@ -1508,6 +1527,7 @@ def __init__(
 
             layers.append(ModuleList([
                 linear_attn,
+                colt5_attn,
                 conditionable_pair_bias,
                 conditionable_transition
             ]))

@@ -1560,11 +1580,14 @@ def forward(
 
         # main transformer
 
-        for linear_attn, attn, transition in self.layers:
+        for linear_attn, colt5_attn, attn, transition in self.layers:
 
             if exists(linear_attn):
                 noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr
 
+            if exists(colt5_attn):
+                noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr
+
             attn_out = attn(
                 noised_repr,
                 cond = single_repr,
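The routed block can also be exercised on its own. Below is a standalone sketch using only the constructor arguments and call convention visible in the diff above; the batch, sequence, and feature sizes are illustrative, not from the commit:

```python
import torch
from colt5_attention import ConditionalRoutedAttention

# same construction as in DiffusionTransformer.__init__ above;
# has_light_attn = False drops CoLT5's cheap "light" branch, keeping
# only the routed "heavy" attention over the selected tokens
colt5_attn = ConditionalRoutedAttention(
    dim = 384,
    has_light_attn = False,
    heavy_dim_head = 64,
    heavy_heads = 8,
    num_heavy_tokens_q = 512,
    num_heavy_tokens_kv = 512
)

tokens = torch.randn(1, 2048, 384)               # (batch, atoms, dim), illustrative
mask = torch.ones(1, 2048, dtype = torch.bool)

# residual application, matching the forward pass above
out = colt5_attn(tokens, mask = mask) + tokens
```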

pyproject.toml
Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.1.51"
+version = "0.1.52"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

@@ -25,6 +25,7 @@ classifiers=[
 dependencies = [
     "beartype",
     "biopython>=1.83",
+    "CoLT5-attention>=0.11.0",
     "einops>=0.8.0",
     "einx>=0.2.2",
     "ema-pytorch>=0.5.0",

tests/test_af3.py
Lines changed: 4 additions & 2 deletions

@@ -206,7 +206,8 @@ def test_msa_module():
     assert pairwise.shape == pairwise_out.shape
 
 @pytest.mark.parametrize('use_linear_attn', (False, True))
-def test_diffusion_transformer(use_linear_attn):
+@pytest.mark.parametrize('use_colt5_attn', (False, True))
+def test_diffusion_transformer(use_linear_attn, use_colt5_attn):
 
     single = torch.randn(2, 16, 384)
     pairwise = torch.randn(2, 16, 16, 128)

@@ -215,7 +216,8 @@ def test_diffusion_transformer(use_linear_attn):
     diffusion_transformer = DiffusionTransformer(
         depth = 2,
         heads = 16,
-        use_linear_attn = use_linear_attn
+        use_linear_attn = use_linear_attn,
+        use_colt5_attn = use_colt5_attn
     )
 
     single_out = diffusion_transformer(
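Since the two `parametrize` decorators stack, pytest now runs this test in four configurations: each attention variant alone, both enabled, and neither.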
