4747
4848# constants
4949
50- IS_MOLECULE_TYPES = 4
50+ IS_MOLECULE_TYPES = 5
51+ IS_PROTEIN_INDEX = 0
52+ IS_LIGAND_INDEX = - 2
53+ IS_METAL_ION_INDEX = - 1
54+ IS_BIOMOLECULE_INDICES = slice (0 , 3 )
55+
5156ADDITIONAL_MOLECULE_FEATS = 5
5257
5358CCD_COMPONENTS_FILEPATH = os .path .join ('data' , 'ccd_data' , 'components.cif' )
@@ -243,7 +248,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
243248 if not exists (atom_lens ):
244249 atom_lens = []
245250
246- for mol , is_ligand in zip (molecules , i .is_molecule_types [:, - 1 ]):
251+ for mol , is_ligand in zip (molecules , i .is_molecule_types [:, IS_LIGAND_INDEX ]):
247252 num_atoms = mol .GetNumAtoms ()
248253
249254 if is_ligand :
@@ -347,7 +352,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
347352 asym_ids = F .pad (asym_ids , (1 , 0 ), value = - 1 )
348353 is_first_mol_in_chains = (asym_ids [1 :] - asym_ids [:- 1 ]) == 1
349354
350- is_chainable_biomolecules = i .is_molecule_types [..., : 3 ].any (dim = - 1 )
355+ is_chainable_biomolecules = i .is_molecule_types [..., IS_BIOMOLECULE_INDICES ].any (dim = - 1 )
351356
352357 # for every molecule, build the bonds id matrix and add to `atompair_ids`
353358
@@ -746,9 +751,10 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
746751 len (all_rna_mols ),
747752 len (all_dna_mols ),
748753 total_ligand_tokens ,
754+ num_metal_ions
749755 ]
750756
751- num_tokens = sum (molecule_type_token_lens ) + num_metal_ions
757+ num_tokens = sum (molecule_type_token_lens )
752758
753759 assert num_tokens > 0 , "you have an empty alphafold3 input"
754760
@@ -757,7 +763,6 @@ def alphafold3_input_to_molecule_input(alphafold3_input: Alphafold3Input) -> Mol
757763 molecule_types_lens_cumsum = tensor ([0 , * molecule_type_token_lens ]).cumsum (dim = - 1 )
758764 left , right = molecule_types_lens_cumsum [:- 1 ], molecule_types_lens_cumsum [1 :]
759765
760- # TODO: fix bug that may leave molecules with no assigned type
761766 is_molecule_types = (arange >= left ) & (arange < right )
762767
763768 # all molecules, layout is
@@ -950,7 +955,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
950955 for mol_index , (mol_miss_atom_indices , mol ) in enumerate (
951956 zip (i .missing_atom_indices , molecules )
952957 ):
953- is_ligand_residue = is_molecule_types [mol_index , - 1 ].item ()
958+ is_ligand_residue = is_molecule_types [mol_index , IS_LIGAND_INDEX ].item ()
954959 mol_miss_atom_indices = default (mol_miss_atom_indices , [])
955960 mol_miss_atom_indices = tensor (mol_miss_atom_indices , dtype = torch .long )
956961
@@ -1427,9 +1432,9 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14271432 molecule_ids = torch .from_numpy (biomol .restype )
14281433
14291434 # retrieve is_molecule_types from the `Biomolecule` object, which is a boolean tensor of shape [*, 4]
1430- # is_protein | is_rna | is_dna | is_ligand
1435+ # is_protein | is_rna | is_dna | is_ligand | is_metal_ion
14311436 # this is needed for their special diffusion loss
1432- n_one_hot = 4
1437+ n_one_hot = 5
14331438 is_molecule_types = F .one_hot (torch .from_numpy (biomol .chemtype ), num_classes = n_one_hot ).bool ()
14341439
14351440 # manually derive remaining features using the `Biomolecule` object
@@ -1464,11 +1469,20 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput, training: bool = True) -> M
14641469 molecule_atom_types .extend ([mol_type ] * num_atoms )
14651470 # ensure modified polymer residues are one-hot encoded as ligands
14661471 # TODO: double-check whether this handling of modified polymer residues makes sense
1467- is_molecule_types [molecule_idx : molecule_idx + num_atoms , : n_one_hot - 1 ] = False
1468- is_molecule_types [molecule_idx : molecule_idx + num_atoms , n_one_hot - 1 ] = True
1472+
1473+ molecule_type_row_idx = slice (molecule_idx , molecule_idx + num_atoms )
1474+
1475+ is_molecule_types [molecule_type_row_idx , IS_BIOMOLECULE_INDICES ] = False
1476+
14691477 if num_atoms == 1 :
14701478 # NOTE: we manually set the molecule ID of ions to the `gap` ID
1471- molecule_ids [molecule_idx ] = gap_id
1479+ molecule_ids [molecule_type_row_idx ] = gap_id
1480+ is_mol_type_index = IS_METAL_ION_INDEX
1481+ else :
1482+ is_mol_type_index = IS_LIGAND_INDEX
1483+
1484+ is_molecule_types [molecule_type_row_idx , is_mol_type_index ] = True
1485+
14721486 molecule_idx += num_atoms
14731487 else :
14741488 token_pool_lens .append (num_atoms )
0 commit comments