Skip to content

Commit f64f81c

Browse files
committed
just do the molecule ids correctly, give metal ion its own id
1 parent 1be7cf4 commit f64f81c

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
IS_PROTEIN_INDEX,
4444
IS_LIGAND_INDEX,
4545
IS_BIOMOLECULE_INDICES,
46+
NUM_MOLECULE_IDS,
4647
ADDITIONAL_MOLECULE_FEATS
4748
)
4849

@@ -3001,8 +3002,8 @@ def __init__(
30013002
dim_single = 384,
30023003
dim_pairwise = 128,
30033004
dim_token = 768,
3004-
dim_additional_token_feats = 2, # in paper, they include two meta information per token (f_profile, f_deletion_mean)
3005-
num_molecule_types: int = 32, # restype in additional residue information, apparently 32 (must be human amino acids + nucleotides + something else)
3005+
dim_additional_token_feats = 2, # in paper, they include two meta information per token (f_profile, f_deletion_mean)
3006+
num_molecule_types: int = NUM_MOLECULE_IDS, # restype in additional residue information, apparently 32. will do 33 to account for metal ions
30063007
num_atom_embeds: int | None = None,
30073008
num_atompair_embeds: int | None = None,
30083009
num_molecule_mods: int | None = None,

alphafold3_pytorch/inputs.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from alphafold3_pytorch.data import mmcif_parsing
3131
from alphafold3_pytorch.data.data_pipeline import get_assembly
32+
3233
from alphafold3_pytorch.life import (
3334
ATOM_BONDS,
3435
ATOMS,
@@ -40,6 +41,8 @@
4041
reverse_complement,
4142
reverse_complement_tensor,
4243
)
44+
45+
4346
from alphafold3_pytorch.tensor_typing import Bool, Float, Int, typecheck
4447
from alphafold3_pytorch.utils.data_utils import RESIDUE_MOLECULE_TYPE, get_residue_molecule_type
4548
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
@@ -53,6 +56,10 @@
5356
IS_METAL_ION_INDEX = -1
5457
IS_BIOMOLECULE_INDICES = slice(0, 3)
5558

59+
MOLECULE_GAP_ID = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) - 1
60+
MOLECULE_METAL_ION_ID = MOLECULE_GAP_ID + 1
61+
NUM_MOLECULE_IDS = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES) + 2
62+
5663
ADDITIONAL_MOLECULE_FEATS = 5
5764

5865
CCD_COMPONENTS_FILEPATH = os.path.join('data', 'ccd_data', 'components.cif')
@@ -621,9 +628,7 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
621628

622629
rna_offset = len(HUMAN_AMINO_ACIDS)
623630
dna_offset = len(RNA_NUCLEOTIDES) + rna_offset
624-
625631
ligand_id = len(HUMAN_AMINO_ACIDS) - 1
626-
gap_id = len(DNA_NUCLEOTIDES) + dna_offset
627632

628633
molecule_ids = []
629634

@@ -703,7 +708,7 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
703708
metal_ions = alphafold3_input.metal_ions
704709
mol_metal_ions = map_int_or_string_indices_to_mol(METALS, metal_ions)
705710

706-
molecule_ids.append(tensor([gap_id] * len(mol_metal_ions)))
711+
molecule_ids.append(tensor([MOLECULE_METAL_ION_ID] * len(mol_metal_ions)))
707712

708713
# convert ligands to rdchem.Mol
709714

@@ -1434,7 +1439,7 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14341439
# retrieve is_molecule_types from the `Biomolecule` object, which is a boolean tensor of shape [*, 4]
14351440
# is_protein | is_rna | is_dna | is_ligand | is_metal_ion
14361441
# this is needed for their special diffusion loss
1437-
n_one_hot = 5
1442+
n_one_hot = IS_MOLECULE_TYPES
14381443
is_molecule_types = F.one_hot(torch.from_numpy(biomol.chemtype), num_classes=n_one_hot).bool()
14391444

14401445
# manually derive remaining features using the `Biomolecule` object
@@ -1460,7 +1465,7 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14601465
molecule_idx = 0
14611466
token_pool_lens = []
14621467
molecule_atom_types = []
1463-
gap_id = len(HUMAN_AMINO_ACIDS) + len(RNA_NUCLEOTIDES) + len(DNA_NUCLEOTIDES)
1468+
14641469
for mol, mol_type in zip(molecules, molecule_types):
14651470
num_atoms = mol.GetNumAtoms()
14661471
if mol_type == "ligand":
@@ -1476,7 +1481,7 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14761481

14771482
if num_atoms == 1:
14781483
# NOTE: we manually set the molecule ID of ions to the `gap` ID
1479-
molecule_ids[molecule_type_row_idx] = gap_id
1484+
molecule_ids[molecule_idx] = MOLECULE_METAL_ION_ID
14801485
is_mol_type_index = IS_METAL_ION_INDEX
14811486
else:
14821487
is_mol_type_index = IS_LIGAND_INDEX

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.1"
3+
version = "0.2.2"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

0 commit comments

Comments
 (0)