Skip to content

Commit 91da5fb

Browse files
committed
add token_bonds to AtomInput (should be renamed to molecular bonds at some point, except for the fact that they treat each atom of the ligand as a token..)
1 parent 5451683 commit 91da5fb

File tree

5 files changed

+14
-10
lines changed

5 files changed

+14
-10
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3221,7 +3221,7 @@ def forward(
32213221
atom_ids: Int['b m'] | None = None,
32223222
atompair_ids: Int['b m m'] | Int['b nw w1 w2'] | None = None,
32233223
atom_mask: Bool['b m'] | None = None,
3224-
token_bond: Bool['b n n'] | None = None,
3224+
token_bonds: Bool['b n n'] | None = None,
32253225
msa: Float['b s n d'] | None = None,
32263226
msa_mask: Bool['b s'] | None = None,
32273227
templates: Float['b t n n dt'] | None = None,
@@ -3322,21 +3322,21 @@ def forward(
33223322

33233323
# token bond features
33243324

3325-
if exists(token_bond):
3325+
if exists(token_bonds):
33263326
# we'll do some precautionary standardization
33273327
# (1) mask out diagonal - token to itself does not count as a bond
33283328
# (2) symmetrize, in case it is not already symmetrical (could also throw an error)
33293329

3330-
token_bond = token_bond | rearrange(token_bond, 'b i j -> b j i')
3330+
token_bonds = token_bonds | rearrange(token_bonds, 'b i j -> b j i')
33313331
diagonal = torch.eye(seq_len, device = self.device, dtype = torch.bool)
3332-
token_bond = token_bond.masked_fill(diagonal, False)
3332+
token_bonds = token_bonds.masked_fill(diagonal, False)
33333333
else:
33343334
seq_arange = torch.arange(seq_len, device = self.device)
3335-
token_bond = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
3335+
token_bonds = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
33363336

3337-
token_bond_feats = self.token_bond_to_pairwise_feat(token_bond.float())
3337+
token_bonds_feats = self.token_bond_to_pairwise_feat(token_bonds.float())
33383338

3339-
pairwise_init = pairwise_init + token_bond_feats
3339+
pairwise_init = pairwise_init + token_bonds_feats
33403340

33413341
# molecule mask and pairwise mask
33423342

alphafold3_pytorch/inputs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class AtomInput(TypedDict):
1616
additional_molecule_feats: Float['n 9']
1717
templates: Float['t n n dt']
1818
msa: Float['s n dm']
19+
token_bonds: Bool['n n'] | None
1920
atom_ids: Int['m'] | None
2021
atompair_ids: Int['m m'] | Int['nw w (w*2)'] | None
2122
template_mask: Bool['t'] | None
@@ -36,6 +37,7 @@ class BatchedAtomInput(TypedDict):
3637
additional_molecule_feats: Float['b n 9']
3738
templates: Float['b t n n dt']
3839
msa: Float['b s n dm']
40+
token_bonds: Bool['b n n'] | None
3941
atom_ids: Int['b m'] | None
4042
atompair_ids: Int['b m m'] | Int['b nw w (w*2)'] | None
4143
template_mask: Bool['b t'] | None

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.60"
3+
version = "0.1.62"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def test_alphafold3(
422422
molecule_atom_lens = torch.randint(1, 3, (2, seq_len))
423423
atom_seq_len = molecule_atom_lens.sum(dim = -1).amax()
424424

425-
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
425+
token_bonds = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
426426

427427
atom_inputs = torch.randn(2, atom_seq_len, 77)
428428

@@ -480,7 +480,7 @@ def test_alphafold3(
480480
molecule_atom_lens = molecule_atom_lens,
481481
atompair_inputs = atompair_inputs,
482482
additional_molecule_feats = additional_molecule_feats,
483-
token_bond = token_bond,
483+
token_bonds = token_bonds,
484484
msa = msa,
485485
msa_mask = msa_mask,
486486
templates = template_feats,

tests/test_trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __getitem__(self, idx):
4848
molecule_atom_lens = torch.randint(1, self.atoms_per_window, (seq_len,))
4949
additional_molecule_feats = torch.randn(seq_len, 9)
5050
molecule_ids = torch.randint(0, 32, (seq_len,))
51+
token_bonds = torch.randint(0, 2, (seq_len, seq_len)).bool()
5152

5253
templates = torch.randn(2, seq_len, seq_len, 44)
5354
template_mask = torch.ones((2,)).bool()
@@ -73,6 +74,7 @@ def __getitem__(self, idx):
7374
atom_inputs = atom_inputs,
7475
atompair_inputs = atompair_inputs,
7576
molecule_ids = molecule_ids,
77+
token_bonds = token_bonds,
7678
molecule_atom_lens = molecule_atom_lens,
7779
additional_molecule_feats = additional_molecule_feats,
7880
templates = templates,

0 commit comments

Comments
 (0)