
Commit d16d1e7

add ability to use taylor series linear attention in atom encoder and decoder
1 parent e4e6359 commit d16d1e7

File tree

3 files changed: +36 -7 lines changed


alphafold3_pytorch/alphafold3.py

Lines changed: 31 additions & 3 deletions
@@ -61,6 +61,7 @@
 )

 from alphafold3_pytorch.attention import Attention
+from taylor_series_linear_attention import TaylorSeriesLinearAttn

 import einx
 from einops import rearrange, repeat, reduce, einsum, pack, unpack
@@ -1334,7 +1335,12 @@ def __init__(
         attn_window_size = None,
         attn_pair_bias_kwargs: dict = dict(),
         num_register_tokens = 0,
-        serial = False
+        serial = False,
+        use_linear_attn = False,
+        linear_attn_kwargs = dict(
+            heads = 8,
+            dim_head = 16
+        )
     ):
         super().__init__()
         dim_single_cond = default(dim_single_cond, dim)
@@ -1343,6 +1349,15 @@ def __init__(

         for _ in range(depth):

+            linear_attn = None
+
+            if use_linear_attn:
+                linear_attn = TaylorSeriesLinearAttn(
+                    dim = dim,
+                    prenorm = True,
+                    **linear_attn_kwargs
+                )
+
             pair_bias_attn = AttentionPairBias(
                 dim = dim,
                 dim_pairwise = dim_pairwise,
@@ -1368,6 +1383,7 @@ def __init__(
             )

             layers.append(ModuleList([
+                linear_attn,
                 conditionable_pair_bias,
                 conditionable_transition
             ]))
@@ -1408,7 +1424,10 @@ def forward(

         # main transformer

-        for attn, transition in self.layers:
+        for linear_attn, attn, transition in self.layers:
+
+            if exists(linear_attn):
+                noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr

             attn_out = attn(
                 noised_repr,
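
For reference, the change to forward above inserts a global Taylor-series linear attention pass, with a residual connection, ahead of the windowed pair-bias attention, so atom tokens can exchange information beyond the local attention window at linear cost. A minimal standalone sketch of that step, mirroring the construction and call in the diff (the dim of 384 and the tensor shapes are illustrative assumptions only):

import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn

# same construction as in __init__ above: heads = 8, dim_head = 16, prenorm = True
linear_attn = TaylorSeriesLinearAttn(
    dim = 384,
    heads = 8,
    dim_head = 16,
    prenorm = True
)

noised_repr = torch.randn(2, 16, 384)   # (batch, atom tokens, dim) - illustrative shapes
mask = torch.ones(2, 16).bool()

# residual update, as in the forward pass above
noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr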
@@ -1527,7 +1546,12 @@ def __init__(
         serial = False,
         atom_encoder_kwargs: dict = dict(),
         atom_decoder_kwargs: dict = dict(),
-        token_transformer_kwargs: dict = dict()
+        token_transformer_kwargs: dict = dict(),
+        use_linear_attn = False,
+        linear_attn_kwargs: dict = dict(
+            heads = 8,
+            dim_head = 16
+        )
     ):
         super().__init__()

@@ -1584,6 +1608,8 @@ def __init__(
             depth = atom_encoder_depth,
             heads = atom_encoder_heads,
             serial = serial,
+            use_linear_attn = use_linear_attn,
+            linear_attn_kwargs = linear_attn_kwargs,
             **atom_encoder_kwargs
         )

@@ -1624,6 +1650,8 @@ def __init__(
             depth = atom_decoder_depth,
             heads = atom_decoder_heads,
             serial = serial,
+            use_linear_attn = use_linear_attn,
+            linear_attn_kwargs = linear_attn_kwargs,
             **atom_decoder_kwargs
         )

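Taken together, a hedged sketch of constructing the transformer directly with the new options enabled (the import path follows the file changed above; depth and heads mirror the test below, and the linear_attn_kwargs shown are just the defaults added in this commit):

from alphafold3_pytorch.alphafold3 import DiffusionTransformer

diffusion_transformer = DiffusionTransformer(
    depth = 2,
    heads = 16,
    use_linear_attn = True,
    linear_attn_kwargs = dict(  # forwarded to TaylorSeriesLinearAttn in each layer
        heads = 8,
        dim_head = 16
    )
)

The module that builds the atom encoder and decoder (the second __init__ changed above) simply passes the same two keyword arguments through, so enabling linear attention there only requires setting use_linear_attn = True on it.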
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.0.50"
+version = "0.0.51"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 4 additions & 3 deletions
@@ -176,16 +176,17 @@ def test_msa_module():

     assert pairwise.shape == pairwise_out.shape

-
-def test_diffusion_transformer():
+@pytest.mark.parametrize('use_linear_attn', (False, True))
+def test_diffusion_transformer(use_linear_attn):

     single = torch.randn(2, 16, 384)
     pairwise = torch.randn(2, 16, 16, 128)
     mask = torch.randint(0, 2, (2, 16)).bool()

     diffusion_transformer = DiffusionTransformer(
         depth = 2,
-        heads = 16
+        heads = 16,
+        use_linear_attn = use_linear_attn
     )

     single_out = diffusion_transformer(
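
The parametrization above runs the same test body once with linear attention disabled and once with it enabled; a rough sketch of what it expands to, using only names already present in the test:

for use_linear_attn in (False, True):
    diffusion_transformer = DiffusionTransformer(
        depth = 2,
        heads = 16,
        use_linear_attn = use_linear_attn
    )
    # ... followed by the same forward pass and shape checks as the test body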
