Skip to content

Commit b63a0d9

Browse files
committed
complete token bonds to spec
1 parent 20962b0 commit b63a0d9

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2187,7 +2187,7 @@ def __init__(
21872187
self,
21882188
*,
21892189
dim_atom_inputs,
2190-
dim_additional_residue_feats,
2190+
dim_additional_residue_feats = 10,
21912191
atoms_per_window = 27,
21922192
dim_atom = 128,
21932193
dim_atompair = 16,
@@ -2558,6 +2558,14 @@ def __init__(
25582558
**relative_position_encoding_kwargs
25592559
)
25602560

2561+
# token bonds
2562+
# Algorithm 1 - line 5
2563+
2564+
self.token_bond_to_pairwise_feat = nn.Sequential(
2565+
Rearrange('... -> ... 1'),
2566+
LinearNoBias(1, dim_pairwise)
2567+
)
2568+
25612569
# templates
25622570

25632571
self.template_embedder = TemplateEmbedder(
@@ -2654,7 +2662,8 @@ def forward(
26542662
atom_inputs: Float['b m dai'],
26552663
atom_mask: Bool['b m'],
26562664
atompair_feats: Float['b m m dap'],
2657-
additional_residue_feats: Float['b n rf'],
2665+
additional_residue_feats: Float['b n 10'],
2666+
token_bond: Bool['b n n'] | None = None,
26582667
msa: Float['b s n d'] | None = None,
26592668
msa_mask: Bool['b s'] | None = None,
26602669
templates: Float['b t n n dt'] | None = None,
@@ -2673,7 +2682,13 @@ def forward(
26732682
return_loss_breakdown = False
26742683
) -> Float['b m 3'] | Float[''] | Tuple[Float[''], LossBreakdown]:
26752684

2685+
# get atom sequence length and residue sequence length
2686+
26762687
w = self.atoms_per_window
2688+
atom_seq_len = atom_inputs.shape[-2]
2689+
2690+
assert divisible_by(atom_seq_len, w)
2691+
seq_len = atom_inputs.shape[-2] // w
26772692

26782693
# embed inputs
26792694

@@ -2698,6 +2713,24 @@ def forward(
26982713

26992714
pairwise_init = pairwise_init + relative_position_encoding
27002715

2716+
# token bond features
2717+
2718+
if exists(token_bond):
2719+
# we'll do some precautionary standardization
2720+
# (1) mask out diagonal - token to itself does not count as a bond
2721+
# (2) symmetrize, in case it is not already symmetrical (could also throw an error)
2722+
2723+
token_bond = token_bond | rearrange(token_bond, 'b i j -> b j i')
2724+
diagonal = torch.eye(seq_len, device = self.device, dtype = torch.bool)
2725+
token_bond.masked_fill_(diagonal, False)
2726+
else:
2727+
seq_arange = torch.arange(seq_len, device = self.device)
2728+
token_bond = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
2729+
2730+
token_bond_feats = self.token_bond_to_pairwise_feat(token_bond.float())
2731+
2732+
pairwise_init = pairwise_init + token_bond_feats
2733+
27012734
# pairwise mask
27022735

27032736
mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.22"
3+
version = "0.0.23"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ def test_alphafold3():
368368
seq_len = 16
369369
atom_seq_len = seq_len * 27
370370

371+
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
372+
371373
atom_inputs = torch.randn(2, atom_seq_len, 77)
372374
atom_mask = torch.ones((2, atom_seq_len)).bool()
373375
atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
@@ -418,6 +420,7 @@ def test_alphafold3():
418420
atom_mask = atom_mask,
419421
atompair_feats = atompair_feats,
420422
additional_residue_feats = additional_residue_feats,
423+
token_bond = token_bond,
421424
msa = msa,
422425
msa_mask = msa_mask,
423426
templates = template_feats,

0 commit comments

Comments
 (0)