@@ -819,6 +819,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
819819
820820 molecule_atom_indices = i .molecule_atom_indices
821821 distogram_atom_indices = i .distogram_atom_indices
822+ atom_indices_for_frame = i .atom_indices_for_frame
822823
823824 if exists (missing_token_indices ) and missing_token_indices .shape [- 1 ]:
824825 is_missing_molecule_atom = einx .equal (
@@ -827,9 +828,29 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
827828 is_missing_distogram_atom = einx .equal (
828829 "n missing, n -> n missing" , missing_token_indices , distogram_atom_indices
829830 ).any (dim = - 1 )
831+ is_missing_atom_indices_for_frame = einx .equal (
832+ "n missing, n three -> n three missing" , missing_token_indices , atom_indices_for_frame
833+ ).any (dim = - 1 )
830834
831835 molecule_atom_indices = molecule_atom_indices .masked_fill (is_missing_molecule_atom , - 1 )
832836 distogram_atom_indices = distogram_atom_indices .masked_fill (is_missing_distogram_atom , - 1 )
837+ atom_indices_for_frame = atom_indices_for_frame .masked_fill (
838+ is_missing_atom_indices_for_frame , - 1
839+ )
840+
841+ # sanity-check the atom indices
842+ if not (- 1 <= molecule_atom_indices .min () <= molecule_atom_indices .max () < total_atoms ):
843+ raise ValueError (
844+ f"Invalid molecule atom indices found in `molecule_to_atom_input()` for { i .filepath } : { molecule_atom_indices } "
845+ )
846+ if not (- 1 <= distogram_atom_indices .min () <= distogram_atom_indices .max () < total_atoms ):
847+ raise ValueError (
848+ f"Invalid distogram atom indices found in `molecule_to_atom_input()` for { i .filepath } : { distogram_atom_indices } "
849+ )
850+ if not (- 1 <= atom_indices_for_frame .min () <= atom_indices_for_frame .max () < total_atoms ):
851+ raise ValueError (
852+ f"Invalid atom indices for frame found in `molecule_to_atom_input()` for { i .filepath } : { atom_indices_for_frame } "
853+ )
833854
834855 # handle atom positions
835856
@@ -856,9 +877,9 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
856877 atompair_inputs = atompair_inputs ,
857878 molecule_atom_lens = atom_lens .long (),
858879 molecule_ids = i .molecule_ids ,
859- molecule_atom_indices = i . molecule_atom_indices ,
860- distogram_atom_indices = i . distogram_atom_indices ,
861- atom_indices_for_frame = i . atom_indices_for_frame ,
880+ molecule_atom_indices = molecule_atom_indices ,
881+ distogram_atom_indices = distogram_atom_indices ,
882+ atom_indices_for_frame = atom_indices_for_frame ,
862883 is_molecule_mod = is_molecule_mod ,
863884 msa = i .msa ,
864885 templates = i .templates ,
@@ -3025,12 +3046,14 @@ def pdb_input_to_molecule_input(
30253046
30263047 current_atom_index = 0
30273048 current_res_index = - 1
3049+ current_chain_index = - 1
30283050
3029- for mol_type , atom_mask , chemid , res_index in zip (
3051+ for mol_type , atom_mask , chemid , res_index , res_chain_index in zip (
30303052 molecule_atom_types ,
30313053 biomol .atom_mask ,
30323054 biomol .chemid ,
30333055 biomol .residue_index ,
3056+ biomol .chain_index ,
30343057 ):
30353058 residue_constants = get_residue_constants (
30363059 mol_type .replace ("protein" , "peptide" ).replace ("mod_" , "" )
@@ -3045,11 +3068,12 @@ def pdb_input_to_molecule_input(
30453068
30463069 if is_atomized_residue (mol_type ):
30473070 # collect indices for each ligand and modified polymer residue token (i.e., atom)
3048- if current_res_index == res_index :
3071+ if current_res_index == res_index and current_chain_index == res_chain_index :
30493072 current_atom_index += 1
30503073 else :
30513074 current_atom_index = 0
30523075 current_res_index = res_index
3076+ current_chain_index = res_chain_index
30533077
30543078 # NOTE: we have to dynamically determine the token center atom index for atomized residues
30553079 token_center_atom_index = np .where (atom_mask )[0 ][0 ]
@@ -3439,6 +3463,20 @@ def pdb_input_to_molecule_input(
34393463 )
34403464 num_atoms = atom_pos .shape [0 ]
34413465
3466+ # sanity-check the atom indices
3467+ if not (- 1 <= distogram_atom_indices .min () <= distogram_atom_indices .max () < num_atoms ):
3468+ raise ValueError (
3469+ f"Invalid distogram atom indices found in `pdb_input_to_molecule_input()` for { filepath } : { distogram_atom_indices } "
3470+ )
3471+ if not (- 1 <= molecule_atom_indices .min () <= molecule_atom_indices .max () < num_atoms ):
3472+ raise ValueError (
3473+ f"Invalid molecule atom indices found in `pdb_input_to_molecule_input()` for { filepath } : { molecule_atom_indices } "
3474+ )
3475+ if not (- 1 <= atom_indices_for_frame .min () <= atom_indices_for_frame .max () < num_atoms ):
3476+ raise ValueError (
3477+ f"Invalid atom indices for frame found in `pdb_input_to_molecule_input()` for { filepath } : { atom_indices_for_frame } "
3478+ )
3479+
34423480 # create atom_parent_ids using the `Biomolecule` object, which governs in the atom
34433481 # encoder / decoder which atom attends to which, where a design choice is made such
34443482 # that mmCIF author chain indices are directly adopted to group atoms belonging to
0 commit comments