Skip to content

Commit 64fc0b1

Browse files
committed
able to restrict atom attention to intramolecular in atom encoder / decoder, so that in the case of sequence-local attention, a ligand does not erroneously attend to one end of some polypeptide
1 parent 5812d8d commit 64fc0b1

File tree

4 files changed

+57
-6
lines changed

4 files changed

+57
-6
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,7 +1555,8 @@ def forward(
15551555
*,
15561556
single_repr: Float['b n ds'],
15571557
pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
1558-
mask: Bool['b n'] | None = None
1558+
mask: Bool['b n'] | None = None,
1559+
windowed_mask: Bool['b nw w (w*2)'] | None = None
15591560
):
15601561
w = self.attn_window_size
15611562
has_windows = exists(w)
@@ -1596,7 +1597,8 @@ def forward(
15961597
noised_repr,
15971598
cond = single_repr,
15981599
pairwise_repr = pairwise_repr,
1599-
mask = mask
1600+
mask = mask,
1601+
windowed_mask = windowed_mask
16001602
)
16011603

16021604
if serial:
@@ -1806,7 +1808,8 @@ def forward(
18061808
single_inputs_repr: Float['b n dsi'],
18071809
pairwise_trunk: Float['b n n dpt'],
18081810
pairwise_rel_pos_feats: Float['b n n dpr'],
1809-
molecule_atom_lens: Int['b n']
1811+
molecule_atom_lens: Int['b n'],
1812+
atom_parent_ids: Int['b m'] | None = None
18101813
):
18111814
w = self.atoms_per_window
18121815
device = noised_atom_pos.device
@@ -1887,11 +1890,22 @@ def forward(
18871890

18881891
atompair_feats = self.atompair_feats_mlp(atompair_feats) + atompair_feats
18891892

1893+
# take care of restricting atom attention to be intra molecular, if the atom_parent_ids were passed in
1894+
1895+
windowed_mask = None
1896+
1897+
if exists(atom_parent_ids):
1898+
atom_parent_ids_rows = pad_and_window(atom_parent_ids, w)
1899+
atom_parent_ids_columns = concat_previous_window(atom_parent_ids_rows, dim_seq = 1, dim_window = 2)
1900+
1901+
windowed_mask = einx.equal('b n i, b n j -> b n i j', atom_parent_ids_rows, atom_parent_ids_columns)
1902+
18901903
# atom encoder
18911904

18921905
atom_feats = self.atom_encoder(
18931906
atom_feats,
18941907
mask = atom_mask,
1908+
windowed_mask = windowed_mask,
18951909
single_repr = atom_feats_cond,
18961910
pairwise_repr = atompair_feats
18971911
)
@@ -1929,6 +1943,7 @@ def forward(
19291943
atom_feats = self.atom_decoder(
19301944
atom_decoder_input,
19311945
mask = atom_mask,
1946+
windowed_mask = windowed_mask,
19321947
single_repr = atom_feats_cond,
19331948
pairwise_repr = atompair_feats
19341949
)
@@ -2154,6 +2169,7 @@ def forward(
21542169
pairwise_trunk: Float['b n n dpt'],
21552170
pairwise_rel_pos_feats: Float['b n n dpr'],
21562171
molecule_atom_lens: Int['b n'],
2172+
atom_parent_ids: Int['b m'] | None = None,
21572173
return_denoised_pos = False,
21582174
additional_molecule_feats: Float[f'b n {ADDITIONAL_MOLECULE_FEATS}'] | None = None,
21592175
add_smooth_lddt_loss = False,
@@ -2181,6 +2197,7 @@ def forward(
21812197
atom_feats = atom_feats,
21822198
atom_mask = atom_mask,
21832199
atompair_feats = atompair_feats,
2200+
atom_parent_ids = atom_parent_ids,
21842201
mask = mask,
21852202
single_trunk_repr = single_trunk_repr,
21862203
single_inputs_repr = single_inputs_repr,
@@ -3222,6 +3239,7 @@ def forward(
32223239
atom_ids: Int['b m'] | None = None,
32233240
atompair_ids: Int['b m m'] | Int['b nw w1 w2'] | None = None,
32243241
atom_mask: Bool['b m'] | None = None,
3242+
atom_parent_ids: Int['b m'] | None = None,
32253243
token_bonds: Bool['b n n'] | None = None,
32263244
msa: Float['b s n d'] | None = None,
32273245
msa_mask: Bool['b s'] | None = None,
@@ -3426,6 +3444,7 @@ def forward(
34263444
num_sample_steps = num_sample_steps,
34273445
atom_feats = atom_feats,
34283446
atompair_feats = atompair_feats,
3447+
atom_parent_ids = atom_parent_ids,
34293448
atom_mask = atom_mask,
34303449
mask = mask,
34313450
single_trunk_repr = single,
@@ -3483,6 +3502,7 @@ def forward(
34833502
atom_pos,
34843503
atom_mask,
34853504
atom_feats,
3505+
atom_parent_ids,
34863506
atompair_feats,
34873507
mask,
34883508
pairwise_mask,
@@ -3504,6 +3524,7 @@ def forward(
35043524
atom_pos,
35053525
atom_mask,
35063526
atom_feats,
3527+
atom_parent_ids,
35073528
atompair_feats,
35083529
mask,
35093530
pairwise_mask,
@@ -3547,6 +3568,7 @@ def forward(
35473568
add_bond_loss = diffusion_add_bond_loss,
35483569
atom_feats = atom_feats,
35493570
atompair_feats = atompair_feats,
3571+
atom_parent_ids = atom_parent_ids,
35503572
atom_mask = atom_mask,
35513573
mask = mask,
35523574
single_trunk_repr = single,

alphafold3_pytorch/attention.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def forward(
222222
seq: Float['b i d'],
223223
mask: Bool['b n']| None = None,
224224
context: Float['b j d'] | None = None,
225+
windowed_mask: Bool['b nw w (w*2)'] | None = None,
225226
attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None
226227

227228
) -> Float['b i d']:
@@ -239,6 +240,7 @@ def forward(
239240
q, k, v,
240241
attn_bias = attn_bias,
241242
mask = mask,
243+
windowed_mask = windowed_mask,
242244
memory_kv = self.memory_kv
243245
)
244246

@@ -324,6 +326,7 @@ def local_attn(
324326
k: Float['b h n d'],
325327
v: Float['b h n d'],
326328
mask: Bool['b n'] | None = None,
329+
windowed_mask: Bool['b nw w (w*2)'] | None = None,
327330
attn_bias: Float['... n n'] | Float['... nw w (w*2)'] | None = None,
328331
memory_kv: Float['2 h m d'] | None = None
329332
) -> Float['b h n d']:
@@ -386,6 +389,9 @@ def local_attn(
386389
if exists(attn_bias):
387390
attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.)
388391

392+
if exists(windowed_mask):
393+
windowed_mask = pad_at_dim(windowed_mask, (num_mem_kv, 0), value = True)
394+
389395
if exists(mask):
390396
mask = pad_at_dim(mask, (num_mem_kv, 0), value = True)
391397

@@ -400,13 +406,27 @@ def local_attn(
400406
assert attn_bias.ndim == sim.ndim
401407
sim = sim + attn_bias
402408

409+
# windowed masking - for masking out atoms not belonging to the same molecule / polypeptide / nucleic acid in sequence-local attention
410+
411+
if exists(windowed_mask):
412+
sim = einx.where(
413+
'b n i j, b h n i j, -> b h n i j',
414+
windowed_mask, sim, max_neg_value(sim)
415+
)
416+
417+
# mask out buckets of padding
418+
403419
sim = einx.where(
404420
'b n j, b h n i j, -> b h n i j',
405421
mask, sim, max_neg_value(sim)
406422
)
407423

424+
# local attention
425+
408426
attn = sim.softmax(dim = -1)
409427

428+
# aggregate
429+
410430
out = einsum(attn, v, "... i j, ... j d -> ... i d")
411431

412432
# un-window the output
@@ -426,6 +446,7 @@ def forward(
426446
k: Float['b h j d'],
427447
v: Float['b h j d'],
428448
mask: Bool['b j'] | None = None,
449+
windowed_mask: Bool['b nw w (w*2)'] | None = None,
429450
attn_bias: Float['... i j'] | Float['... nw w (w*2)'] | None = None,
430451
memory_kv: Float['2 h m d'] | None = None
431452
) -> Float['b h i d']:
@@ -439,7 +460,7 @@ def forward(
439460
# todo (handle attn bias efficiently)
440461

441462
if self.is_local_attn:
442-
return self.local_attn(q, k, v, mask = mask, attn_bias = attn_bias, memory_kv = memory_kv)
463+
return self.local_attn(q, k, v, mask = mask, windowed_mask = windowed_mask, attn_bias = attn_bias, memory_kv = memory_kv)
443464

444465
assert not exists(is_windowed_attn_bias) or not is_windowed_attn_bias
445466

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.64"
3+
version = "0.1.65"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,11 @@ def test_distogram_head():
412412

413413
@pytest.mark.parametrize('window_atompair_inputs', (True, False))
414414
@pytest.mark.parametrize('stochastic_frame_average', (True, False))
415+
@pytest.mark.parametrize('atom_transformer_intramolecular_attn', (True, False))
415416
def test_alphafold3(
416417
window_atompair_inputs: bool,
417-
stochastic_frame_average: bool
418+
stochastic_frame_average: bool,
419+
atom_transformer_intramolecular_attn: bool
418420
):
419421
seq_len = 16
420422
atoms_per_window = 27
@@ -434,6 +436,11 @@ def test_alphafold3(
434436
additional_molecule_feats = torch.randn(2, seq_len, 9)
435437
molecule_ids = torch.randint(0, 32, (2, seq_len))
436438

439+
atom_parent_ids = None
440+
441+
if atom_transformer_intramolecular_attn:
442+
atom_parent_ids = torch.ones(2, atom_seq_len).long()
443+
437444
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
438445
template_mask = torch.ones((2, 2)).bool()
439446

@@ -478,6 +485,7 @@ def test_alphafold3(
478485
atom_inputs = atom_inputs,
479486
molecule_ids = molecule_ids,
480487
molecule_atom_lens = molecule_atom_lens,
488+
atom_parent_ids = atom_parent_ids,
481489
atompair_inputs = atompair_inputs,
482490
additional_molecule_feats = additional_molecule_feats,
483491
token_bonds = token_bonds,

0 commit comments

Comments
 (0)