Skip to content

Commit 1be7cf4

Browse files
committed
address #90 and just do is_molecule_types correctly
1 parent 09ca3ba commit 1be7cf4

File tree

5 files changed

+38
-20
lines changed

5 files changed

+38
-20
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
6262

6363
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
6464
additional_token_feats = torch.randn(2, seq_len, 2)
65-
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
65+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
6666
molecule_ids = torch.randint(0, 32, (2, seq_len))
6767

6868
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)

alphafold3_pytorch/alphafold3.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040

4141
from alphafold3_pytorch.inputs import (
4242
IS_MOLECULE_TYPES,
43+
IS_PROTEIN_INDEX,
44+
IS_LIGAND_INDEX,
45+
IS_BIOMOLECULE_INDICES,
4346
ADDITIONAL_MOLECULE_FEATS
4447
)
4548

@@ -104,12 +107,13 @@
104107
"""
105108

106109
"""
107-
is_molecule_types: [*, 4]
110+
is_molecule_types: [*, 5]
108111
109112
0: is_protein
110113
1: is_rna
111114
2: is_dna
112115
3: is_ligand
116+
4: is_metal_ions_or_misc
113117
"""
114118

115119
# constants
@@ -2269,7 +2273,7 @@ def forward(
22692273
is_nucleotide_or_ligand_fields = tuple(repeat_consecutive_with_lens(t, molecule_atom_lens) for t in is_nucleotide_or_ligand_fields)
22702274
is_nucleotide_or_ligand_fields = tuple(pad_or_slice_to(t, length = align_weights.shape[-1], dim = -1) for t in is_nucleotide_or_ligand_fields)
22712275

2272-
_, atom_is_dna, atom_is_rna, atom_is_ligand = is_nucleotide_or_ligand_fields
2276+
_, atom_is_dna, atom_is_rna, atom_is_ligand, _ = is_nucleotide_or_ligand_fields
22732277

22742278
# section 3.7.1 equation 4
22752279

@@ -3493,7 +3497,7 @@ def forward(
34933497
# only apply relative positional encodings to biomolecules that are chained
34943498
# not to ligands + metal ions
34953499

3496-
is_chained_biomol = is_molecule_types[..., :3].any(dim = -1) # first three types are chained biomolecules (protein, rna, dna)
3500+
is_chained_biomol = is_molecule_types[..., IS_BIOMOLECULE_INDICES].any(dim = -1) # first three types are chained biomolecules (protein, rna, dna)
34973501
paired_is_chained_biomol = einx.logical_and('b i, b j -> b i j', is_chained_biomol, is_chained_biomol)
34983502

34993503
relative_position_encoding = einx.where(
@@ -3531,7 +3535,7 @@ def forward(
35313535
# prepare mask for msa module and template embedder
35323536
# which is equivalent to the `is_protein` of the `is_molecular_types` input
35333537

3534-
is_protein_mask = is_molecule_types[..., 0]
3538+
is_protein_mask = is_molecule_types[..., IS_PROTEIN_INDEX]
35353539

35363540
# init recycled single and pairwise
35373541

alphafold3_pytorch/inputs.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,12 @@
4747

4848
# constants
4949

50-
IS_MOLECULE_TYPES = 4
50+
IS_MOLECULE_TYPES = 5
51+
IS_PROTEIN_INDEX = 0
52+
IS_LIGAND_INDEX = -2
53+
IS_METAL_ION_INDEX = -1
54+
IS_BIOMOLECULE_INDICES = slice(0, 3)
55+
5156
ADDITIONAL_MOLECULE_FEATS = 5
5257

5358
CCD_COMPONENTS_FILEPATH = os.path.join('data', 'ccd_data', 'components.cif')
@@ -243,7 +248,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
243248
if not exists(atom_lens):
244249
atom_lens = []
245250

246-
for mol, is_ligand in zip(molecules, i.is_molecule_types[:, -1]):
251+
for mol, is_ligand in zip(molecules, i.is_molecule_types[:, IS_LIGAND_INDEX]):
247252
num_atoms = mol.GetNumAtoms()
248253

249254
if is_ligand:
@@ -347,7 +352,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
347352
asym_ids = F.pad(asym_ids, (1, 0), value=-1)
348353
is_first_mol_in_chains = (asym_ids[1:] - asym_ids[:-1]) == 1
349354

350-
is_chainable_biomolecules = i.is_molecule_types[..., :3].any(dim=-1)
355+
is_chainable_biomolecules = i.is_molecule_types[..., IS_BIOMOLECULE_INDICES].any(dim=-1)
351356

352357
# for every molecule, build the bonds id matrix and add to `atompair_ids`
353358

@@ -746,9 +751,10 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
746751
len(all_rna_mols),
747752
len(all_dna_mols),
748753
total_ligand_tokens,
754+
num_metal_ions
749755
]
750756

751-
num_tokens = sum(molecule_type_token_lens) + num_metal_ions
757+
num_tokens = sum(molecule_type_token_lens)
752758

753759
assert num_tokens > 0, "you have an empty alphafold3 input"
754760

@@ -757,7 +763,6 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
757763
molecule_types_lens_cumsum = tensor([0, *molecule_type_token_lens]).cumsum(dim=-1)
758764
left, right = molecule_types_lens_cumsum[:-1], molecule_types_lens_cumsum[1:]
759765

760-
# TODO: fix bug that may leave molecules with no assigned type
761766
is_molecule_types = (arange >= left) & (arange < right)
762767

763768
# all molecules, layout is
@@ -950,7 +955,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
950955
for mol_index, (mol_miss_atom_indices, mol) in enumerate(
951956
zip(i.missing_atom_indices, molecules)
952957
):
953-
is_ligand_residue = is_molecule_types[mol_index, -1].item()
958+
is_ligand_residue = is_molecule_types[mol_index, IS_LIGAND_INDEX].item()
954959
mol_miss_atom_indices = default(mol_miss_atom_indices, [])
955960
mol_miss_atom_indices = tensor(mol_miss_atom_indices, dtype=torch.long)
956961

@@ -1427,9 +1432,9 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14271432
molecule_ids = torch.from_numpy(biomol.restype)
14281433

14291434
# retrieve is_molecule_types from the `Biomolecule` object, which is a boolean tensor of shape [*, 4]
1430-
# is_protein | is_rna | is_dna | is_ligand
1435+
# is_protein | is_rna | is_dna | is_ligand | is_metal_ion
14311436
# this is needed for their special diffusion loss
1432-
n_one_hot = 4
1437+
n_one_hot = 5
14331438
is_molecule_types = F.one_hot(torch.from_numpy(biomol.chemtype), num_classes=n_one_hot).bool()
14341439

14351440
# manually derive remaining features using the `Biomolecule` object
@@ -1464,11 +1469,20 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14641469
molecule_atom_types.extend([mol_type] * num_atoms)
14651470
# ensure modified polymer residues are one-hot encoded as ligands
14661471
# TODO: double-check whether this handling of modified polymer residues makes sense
1467-
is_molecule_types[molecule_idx : molecule_idx + num_atoms, : n_one_hot - 1] = False
1468-
is_molecule_types[molecule_idx : molecule_idx + num_atoms, n_one_hot - 1] = True
1472+
1473+
molecule_type_row_idx = slice(molecule_idx, molecule_idx + num_atoms)
1474+
1475+
is_molecule_types[molecule_type_row_idx, IS_BIOMOLECULE_INDICES] = False
1476+
14691477
if num_atoms == 1:
14701478
# NOTE: we manually set the molecule ID of ions to the `gap` ID
1471-
molecule_ids[molecule_idx] = gap_id
1479+
molecule_ids[molecule_type_row_idx] = gap_id
1480+
is_mol_type_index = IS_METAL_ION_INDEX
1481+
else:
1482+
is_mol_type_index = IS_LIGAND_INDEX
1483+
1484+
is_molecule_types[molecule_type_row_idx, is_mol_type_index] = True
1485+
14721486
molecule_idx += num_atoms
14731487
else:
14741488
token_pool_lens.append(num_atoms)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def test_alphafold3(
444444

445445
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
446446
additional_token_feats = torch.randn(2, 16, 2)
447-
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
447+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
448448
molecule_ids = torch.randint(0, 32, (2, seq_len))
449449

450450
is_molecule_mod = None
@@ -556,7 +556,7 @@ def test_alphafold3_without_msa_and_templates():
556556
atompair_inputs = torch.randn(2, atom_seq_len, atom_seq_len, 5)
557557
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
558558
additional_token_feats = torch.randn(2, seq_len, 2)
559-
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
559+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
560560
molecule_ids = torch.randint(0, 32, (2, seq_len))
561561

562562
atom_pos = torch.randn(2, atom_seq_len, 3)
@@ -716,7 +716,7 @@ def test_alphafold3_with_atom_and_bond_embeddings():
716716

717717
additional_molecule_feats = torch.randint(0, 2, (2, seq_len, 5))
718718
additional_token_feats = torch.randn(2, seq_len, 2)
719-
is_molecule_types = torch.randint(0, 2, (2, seq_len, 4)).bool()
719+
is_molecule_types = torch.randint(0, 2, (2, seq_len, 5)).bool()
720720
molecule_ids = torch.randint(0, 32, (2, seq_len))
721721

722722
template_feats = torch.randn(2, 2, seq_len, seq_len, 44)

0 commit comments

Comments
 (0)