2929)
3030from alphafold3_pytorch .data import mmcif_parsing
3131from alphafold3_pytorch .data .data_pipeline import get_assembly
32+
3233from alphafold3_pytorch .life import (
3334 ATOM_BONDS ,
3435 ATOMS ,
4041 reverse_complement ,
4142 reverse_complement_tensor ,
4243)
44+
45+
4346from alphafold3_pytorch .tensor_typing import Bool , Float , Int , typecheck
4447from alphafold3_pytorch .utils .data_utils import RESIDUE_MOLECULE_TYPE , get_residue_molecule_type
4548from alphafold3_pytorch .utils .model_utils import exclusive_cumsum
5356IS_METAL_ION_INDEX = - 1
5457IS_BIOMOLECULE_INDICES = slice (0 , 3 )
5558
59+ MOLECULE_GAP_ID = len (HUMAN_AMINO_ACIDS ) + len (RNA_NUCLEOTIDES ) + len (DNA_NUCLEOTIDES ) - 1
60+ MOLECULE_METAL_ION_ID = MOLECULE_GAP_ID + 1
61+ NUM_MOLECULE_IDS = len (HUMAN_AMINO_ACIDS ) + len (RNA_NUCLEOTIDES ) + len (DNA_NUCLEOTIDES ) + 2
62+
5663ADDITIONAL_MOLECULE_FEATS = 5
5764
5865CCD_COMPONENTS_FILEPATH = os .path .join ('data' , 'ccd_data' , 'components.cif' )
@@ -621,9 +628,7 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
621628
622629 rna_offset = len (HUMAN_AMINO_ACIDS )
623630 dna_offset = len (RNA_NUCLEOTIDES ) + rna_offset
624-
625631 ligand_id = len (HUMAN_AMINO_ACIDS ) - 1
626- gap_id = len (DNA_NUCLEOTIDES ) + dna_offset
627632
628633 molecule_ids = []
629634
@@ -703,7 +708,7 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
703708 metal_ions = alphafold3_input .metal_ions
704709 mol_metal_ions = map_int_or_string_indices_to_mol (METALS , metal_ions )
705710
706- molecule_ids .append (tensor ([gap_id ] * len (mol_metal_ions )))
711+ molecule_ids .append (tensor ([MOLECULE_METAL_ION_ID ] * len (mol_metal_ions )))
707712
708713 # convert ligands to rdchem.Mol
709714
@@ -1434,7 +1439,7 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14341439 # retrieve is_molecule_types from the `Biomolecule` object, which is a boolean tensor of shape [*, 5]
14351440 # is_protein | is_rna | is_dna | is_ligand | is_metal_ion
14361441 # this is needed for their special diffusion loss
1437- n_one_hot = 5
1442+ n_one_hot = IS_MOLECULE_TYPES
14381443 is_molecule_types = F .one_hot (torch .from_numpy (biomol .chemtype ), num_classes = n_one_hot ).bool ()
14391444
14401445 # manually derive remaining features using the `Biomolecule` object
@@ -1460,7 +1465,7 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14601465 molecule_idx = 0
14611466 token_pool_lens = []
14621467 molecule_atom_types = []
1463- gap_id = len ( HUMAN_AMINO_ACIDS ) + len ( RNA_NUCLEOTIDES ) + len ( DNA_NUCLEOTIDES )
1468+
14641469 for mol , mol_type in zip (molecules , molecule_types ):
14651470 num_atoms = mol .GetNumAtoms ()
14661471 if mol_type == "ligand" :
@@ -1476,7 +1481,7 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14761481
14771482 if num_atoms == 1 :
14781483 # NOTE: we manually set the molecule ID of ions to the dedicated metal ion ID (`MOLECULE_METAL_ION_ID`), not the `gap` ID
1479- molecule_ids [molecule_type_row_idx ] = gap_id
1484+ molecule_ids [molecule_idx ] = MOLECULE_METAL_ION_ID
14801485 is_mol_type_index = IS_METAL_ION_INDEX
14811486 else :
14821487 is_mol_type_index = IS_LIGAND_INDEX
0 commit comments