Skip to content

Commit e5fc64b

Browse files
committed
Revert "token bond should probably be an embedding, for special bonds like disulfides.. fix later, but change to float for now"
This reverts commit 05bc079.
1 parent 05bc079 commit e5fc64b

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2804,7 +2804,7 @@ def forward(
28042804
additional_residue_feats: Float['b n 10'],
28052805
residue_atom_lens: Int['b n'] | None = None,
28062806
atom_mask: Bool['b m'] | None = None,
2807-
token_bond: Float['b n n'] | None = None,
2807+
token_bond: Bool['b n n'] | None = None,
28082808
msa: Float['b s n d'] | None = None,
28092809
msa_mask: Bool['b s'] | None = None,
28102810
templates: Float['b t n n dt'] | None = None,
@@ -2882,14 +2882,14 @@ def forward(
28822882
# (1) mask out diagonal - token to itself does not count as a bond
28832883
# (2) symmetrize, in case it is not already symmetrical (could also throw an error)
28842884

2885-
assert torch.allclose(token_bond, rearrange(token_bond, 'b i j -> b j i')), 'token bond must be symmetrical'
2885+
token_bond = token_bond | rearrange(token_bond, 'b i j -> b j i')
28862886
diagonal = torch.eye(seq_len, device = self.device, dtype = torch.bool)
2887-
token_bond.masked_fill_(diagonal, 0.)
2887+
token_bond.masked_fill_(diagonal, False)
28882888
else:
28892889
seq_arange = torch.arange(seq_len, device = self.device)
2890-
token_bond = (einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1).float()
2890+
token_bond = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1
28912891

2892-
token_bond_feats = self.token_bond_to_pairwise_feat(token_bond)
2892+
token_bond_feats = self.token_bond_to_pairwise_feat(token_bond.float())
28932893

28942894
pairwise_init = pairwise_init + token_bond_feats
28952895

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.29"
3+
version = "0.0.28"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,7 @@ def test_alphafold3():
366366
seq_len = 16
367367
atom_seq_len = seq_len * 27
368368

369-
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).float()
370-
token_bond = token_bond + token_bond.transpose(-1, -2)
369+
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
371370

372371
atom_inputs = torch.randn(2, atom_seq_len, 77)
373372
atom_mask = torch.ones((2, atom_seq_len)).bool()
@@ -513,8 +512,7 @@ def test_alphafold3_with_packed_atom_repr():
513512

514513
atom_seq_len = residue_atom_lens.sum(dim = -1).amax()
515514

516-
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).float()
517-
token_bond = token_bond + token_bond.transpose(-1, -2)
515+
token_bond = torch.randint(0, 2, (2, seq_len, seq_len)).bool()
518516

519517
atom_inputs = torch.randn(2, atom_seq_len, 77)
520518

0 commit comments

Comments
 (0)