
Commit d85f17e

local attention cannot use register tokens, so we will use the next best thing
1 parent ac669e8 commit d85f17e

File tree

alphafold3_pytorch/alphafold3.py
alphafold3_pytorch/attention.py
pyproject.toml
tests/test_af3.py

4 files changed: +56 −4 lines changed

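Register tokens are extra learned tokens packed onto the sequence so that every query has a sink to attend to. Under windowed local attention that trick breaks down: a token prepended to the sequence only ever lands in the first window, so queries in all other windows never see it. The next best thing used here is a set of learned memory key / values, stored as a parameter of the attention module and concatenated onto the keys and values of every attention call (and of every window in the local case), so that all queries can still attend to them. Below is a minimal, self-contained sketch of that mechanism, not the repository's code; shapes follow the einstein notation declared in attention.py, and names such as num_mem_kv are illustrative only.

    import torch
    from torch import nn
    from einops import repeat

    heads, dim_head, num_mem_kv = 8, 64, 4

    # learned key / values shared across the batch, analogous to the
    # self.memory_kv parameter introduced in attention.py below
    memory_kv = nn.Parameter(torch.zeros(2, heads, num_mem_kv, dim_head))
    nn.init.normal_(memory_kv, std = 0.02)

    # pretend projected queries / keys / values
    b, n = 2, 16
    q = torch.randn(b, heads, n, dim_head)
    k = torch.randn(b, heads, n, dim_head)
    v = torch.randn(b, heads, n, dim_head)

    # expand the memory key / values over the batch and prepend them
    # along the key / value sequence dimension
    mk, mv = memory_kv
    mk, mv = (repeat(t, 'h m d -> b h m d', b = b) for t in (mk, mv))
    k = torch.cat((mk, k), dim = -2)
    v = torch.cat((mv, v), dim = -2)

    # every query can now also attend to the num_mem_kv learned slots
    attn = (q @ k.transpose(-1, -2) * dim_head ** -0.5).softmax(dim = -1)
    out = attn @ v
    assert out.shape == (b, heads, n, dim_head)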

alphafold3_pytorch/alphafold3.py

Lines changed: 4 additions & 0 deletions
@@ -570,6 +570,7 @@ def __init__(
         heads,
         dim_pairwise,
         window_size = None,
+        num_memory_kv = 0,
         **attn_kwargs
     ):
         super().__init__()
@@ -579,6 +580,7 @@ def __init__(
         self.attn = Attention(
             heads = heads,
             window_size = window_size,
+            num_memory_kv = num_memory_kv,
             **attn_kwargs
         )
 
@@ -1434,6 +1436,7 @@ def __init__(
         dim_pairwise = 128,
         attn_window_size = None,
         attn_pair_bias_kwargs: dict = dict(),
+        attn_num_memory_kv = False,
         num_register_tokens = 0,
         serial = False,
         use_linear_attn = False,
@@ -1466,6 +1469,7 @@ def __init__(
             dim_pairwise = dim_pairwise,
             heads = heads,
             window_size = attn_window_size,
+            num_memory_kv = attn_num_memory_kv,
             **attn_pair_bias_kwargs
         )
 
alphafold3_pytorch/attention.py

Lines changed: 48 additions & 3 deletions
@@ -165,6 +165,7 @@ def __init__(
         query_bias = True,
         flash = True,
         window_size = None,
+        num_memory_kv: int = 0,
         efficient_attn_config: Config = Config(True, True, True)
     ):
         super().__init__()
@@ -178,6 +179,7 @@ def __init__(
         e - dimension (pairwise rep)
         i - source sequence
         j - context sequence
+        m - memory key / value seq
         """
 
         dim_inner = dim_head * heads
@@ -196,6 +198,12 @@ def __init__(
         self.to_kv = nn.Linear(dim, dim_inner * 2, bias = False)
         self.to_out = nn.Linear(dim_inner, dim, bias = False)
 
+        self.memory_kv = None
+
+        if num_memory_kv > 0:
+            self.memory_kv = nn.Parameter(torch.zeros(2, heads, num_memory_kv, dim_head))
+            nn.init.normal_(self.memory_kv, std = 0.02)
+
         # gating of value
         # allows attention to attend to nothing
 
@@ -230,7 +238,8 @@ def forward(
         out = self.attend(
             q, k, v,
             attn_bias = attn_bias,
-            mask = mask
+            mask = mask,
+            memory_kv = self.memory_kv
         )
 
         # merge heads
@@ -315,7 +324,8 @@ def local_attn(
         k: Float['b h n d'],
         v: Float['b h n d'],
         mask: Bool['b n'] | None = None,
-        attn_bias: Float['... n n'] | Float['... nw w (w*2)'] | None = None
+        attn_bias: Float['... n n'] | Float['... nw w (w*2)'] | None = None,
+        memory_kv: Float['2 h m d'] | None = None
     ) -> Float['b h n d']:
         """
         simple local attention with a radius of 1 window size
@@ -363,6 +373,24 @@ def local_attn(
 
         q = q * scale
 
+        # append memory key / values for local attention windows
+
+        if exists(memory_kv):
+            batch, seq, num_mem_kv = k.shape[0], k.shape[2], memory_kv.shape[-2]
+
+            mk, mv = memory_kv
+            mk, mv = tuple(repeat(t, 'h m d -> b h n m d', b = batch, n = seq) for t in (mk, mv))
+            k = torch.cat((mk, k), dim = -2)
+            v = torch.cat((mv, v), dim = -2)
+
+            if exists(attn_bias):
+                attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.)
+
+            if exists(mask):
+                mask = pad_at_dim(mask, (num_mem_kv, 0), value = True)
+
+        # similarity
+
         sim = einsum(q, k, "... i d, ... j d -> ... i j")
 
         if exists(attn_bias):
@@ -399,6 +427,7 @@ def forward(
         v: Float['b h j d'],
         mask: Bool['b j'] | None = None,
         attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
+        memory_kv: Float['2 h m d'] | None = None
     ) -> Float['b h i d']:
 
         is_windowed_attn_bias = None
@@ -410,10 +439,26 @@ def forward(
         # todo (handle attn bias efficiently)
 
         if self.is_local_attn:
-            return self.local_attn(q, k, v, mask = mask, attn_bias = attn_bias)
+            return self.local_attn(q, k, v, mask = mask, attn_bias = attn_bias, memory_kv = memory_kv)
 
         assert not exists(is_windowed_attn_bias) or not is_windowed_attn_bias
 
+        # append memory key / values
+
+        if exists(memory_kv):
+            batch, num_mem_kv = q.shape[0], memory_kv.shape[-2]
+
+            mk, mv = memory_kv
+            mk, mv = tuple(repeat(t, 'h m d -> b h m d', b = batch) for t in (mk, mv))
+            k = torch.cat((mk, k), dim = -2)
+            v = torch.cat((mv, v), dim = -2)
+
+            if exists(attn_bias):
+                attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.)
+
+            if exists(mask):
+                mask = pad_at_dim(mask, (num_mem_kv, 0), value = True)
+
         # forward to using flash attention if applicable
 
         can_use_flash = self.flash and not exists(attn_bias), 'flash attention does not support attention bias with gradients'
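The same memory_kv parameter feeds both attention paths above; the only difference is how it is broadcast before being concatenated onto the keys and values. In the windowed (local) branch the keys have already been folded into windows, so the memory key / values are repeated once per window ('h m d -> b h n m d', with n the number of windows); in the dense branch they are repeated once per batch element ('h m d -> b h m d'). In both branches the mask is padded with True and the attention bias with zeros on the key dimension, so the memory slots are always attendable and carry no pairwise bias. A rough sketch of the two broadcast patterns follows, using plain F.pad where the repository uses its pad_at_dim helper; the concrete shapes are assumptions based on the type annotations in the diff.

    import torch
    import torch.nn.functional as F
    from einops import repeat

    b, h, m, d = 2, 8, 4, 64                       # batch, heads, memory kv, head dim
    memory_kv = torch.randn(2, h, m, d)            # stand-in for the learned parameter
    mk, mv = memory_kv

    # dense path: one copy of the memory keys per batch element
    k_global = torch.randn(b, h, 32, d)
    mk_global = repeat(mk, 'h m d -> b h m d', b = b)
    k_global = torch.cat((mk_global, k_global), dim = -2)      # (b, h, m + 32, d)

    # local path: keys are already folded into windows of size w * 2,
    # so the memory keys are repeated once per window as well
    nw, w = 4, 8                                               # number of windows, window size
    k_local = torch.randn(b, h, nw, w * 2, d)
    mk_local = repeat(mk, 'h m d -> b h n m d', b = b, n = nw)
    k_local = torch.cat((mk_local, k_local), dim = -2)          # (b, h, nw, m + w * 2, d)

    # the attention bias is padded with zeros on the key dimension so the
    # memory slots pick up no pairwise bias (the mask is padded with True likewise)
    attn_bias = torch.randn(b, h, nw, w, w * 2)
    attn_bias = F.pad(attn_bias, (m, 0), value = 0.)            # (..., w, m + w * 2)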

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.1.11"
+version = "0.1.12"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 3 additions & 0 deletions
@@ -262,6 +262,9 @@ def test_diffusion_module():
         atom_encoder_depth = 1,
         atom_decoder_depth = 1,
         token_transformer_depth = 1,
+        atom_encoder_kwargs = dict(
+            attn_num_memory_kv = 2
+        ),
         token_transformer_kwargs = dict(
             num_register_tokens = 2
         )
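The added test exercises the new knob through atom_encoder_kwargs. For a more direct picture, here is a hypothetical usage sketch of the Attention module on its own: only window_size and num_memory_kv are confirmed by the diff above, so the remaining keyword arguments (dim, dim_head) and the bare forward call on a sequence are assumptions about the signature and may not match the actual module.

    import torch
    from alphafold3_pytorch.attention import Attention   # import path assumed

    attn = Attention(
        dim = 64,            # assumed
        dim_head = 32,       # assumed
        heads = 4,
        window_size = 16,    # windowed local attention - cannot see prepended register tokens
        num_memory_kv = 2    # so give it 2 learned memory key / values instead, as in the test
    )

    seq = torch.randn(1, 128, 64)
    out = attn(seq)          # assumed call signature; expected shape (1, 128, 64)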
