Skip to content

Commit 05bc079

Browse files
committed
token bond should probably be an embedding, for special bonds like disulfides — fix later, but change to float for now
1 parent 0efd7f7 commit 05bc079

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2804,7 +2804,7 @@ def forward(
28042804
additional_residue_feats: Float['b n 10'],
28052805
residue_atom_lens: Int['b n'] | None = None,
28062806
atom_mask: Bool['b m'] | None = None,
2807-
token_bond: Bool['b n n'] | None = None,
2807+
token_bond: Float['b n n'] | None = None,
28082808
msa: Float['b s n d'] | None = None,
28092809
msa_mask: Bool['b s'] | None = None,
28102810
templates: Float['b t n n dt'] | None = None,
@@ -2882,14 +2882,14 @@ def forward(
28822882
# (1) mask out diagonal - token to itself does not count as a bond
28832883
# (2) symmetrize, in case it is not already symmetrical (could also throw an error)
28842884

2885-
token_bond = token_bond | rearrange(token_bond, 'b i j -> b j i')
2885+
assert torch.allclose(token_bond, rearrange(token_bond, 'b i j -> b j i')), 'token bond must be symmetrical'
28862886
diagonal = torch.eye(seq_len, device = self.device, dtype = torch.bool)
2887-
token_bond.masked_fill_(diagonal, False)
2887+
token_bond.masked_fill_(diagonal, 0.)
28882888
else:
28892889
seq_arange = torch.arange(seq_len, device = self.device)
2890-
token_bond = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
2890+
token_bond = (einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1).float()
28912891

2892-
token_bond_feats = self.token_bond_to_pairwise_feat(token_bond.float())
2892+
token_bond_feats = self.token_bond_to_pairwise_feat(token_bond)
28932893

28942894
pairwise_init = pairwise_init + token_bond_feats
28952895

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.28"
3+
version = "0.0.29"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,8 @@ def test_alphafold3():
366366
seq_len = 16
367367
atom_seq_len = seq_len * 27
368368

369-
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
369+
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).float()
370+
token_bond = token_bond + token_bond.transpose(-1, -2)
370371

371372
atom_inputs = torch.randn(2, atom_seq_len, 77)
372373
atom_mask = torch.ones((2, atom_seq_len)).bool()
@@ -512,7 +513,8 @@ def test_alphafold3_with_packed_atom_repr():
512513

513514
atom_seq_len = residue_atom_lens.sum(dim = -1).amax()
514515

515-
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
516+
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).float()
517+
token_bond = token_bond + token_bond.transpose(-1, -2)
516518

517519
atom_inputs = torch.randn(2, atom_seq_len, 77)
518520

0 commit comments

Comments (0)