@@ -635,7 +635,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
635635 molecule_atom_indices = i .molecule_atom_indices
636636 distogram_atom_indices = i .distogram_atom_indices
637637
638- if exists (missing_token_indices ):
638+ if exists (missing_token_indices ) and missing_token_indices . shape [ - 1 ] :
639639 is_missing_molecule_atom = einx .equal (
640640 "n missing, n -> n missing" , missing_token_indices , molecule_atom_indices
641641 ).any (dim = - 1 )
@@ -1064,7 +1064,7 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
10641064
10651065 # mask out molecule atom indices and distogram atom indices where it is in the missing atom indices list
10661066
1067- if exists (missing_token_indices ):
1067+ if exists (missing_token_indices ) and missing_token_indices . shape [ - 1 ] :
10681068 missing_token_indices = repeat_interleave (missing_token_indices , token_repeats , dim = 0 )
10691069
10701070 is_missing_molecule_atom = einx .equal (
@@ -2500,26 +2500,23 @@ def pdb_input_to_molecule_input(
25002500 for mol in molecules
25012501 ]
25022502
2503- if any (molecules_missing_atom_indices ):
2504- missing_atom_indices = []
2505- missing_token_indices = []
2503+ missing_atom_indices = []
2504+ missing_token_indices = []
25062505
2507- for mol_miss_atom_indices , mol , mol_type in zip (
2508- molecules_missing_atom_indices , molecules , molecule_types
2509- ):
2510- mol_miss_atom_indices = default (mol_miss_atom_indices , [])
2511- mol_miss_atom_indices = tensor (mol_miss_atom_indices , dtype = torch .long )
2506+ for mol_miss_atom_indices , mol , mol_type in zip (
2507+ molecules_missing_atom_indices , molecules , molecule_types
2508+ ):
2509+ mol_miss_atom_indices = default (mol_miss_atom_indices , [])
2510+ mol_miss_atom_indices = tensor (mol_miss_atom_indices , dtype = torch .long )
25122511
2513- missing_atom_indices .append (mol_miss_atom_indices )
2514- if is_atomized_residue (mol_type ):
2515- missing_token_indices .extend (
2516- [mol_miss_atom_indices for _ in range (mol .GetNumAtoms ())]
2517- )
2518- else :
2519- missing_token_indices .append (mol_miss_atom_indices )
2512+ missing_atom_indices .append (mol_miss_atom_indices )
2513+ if is_atomized_residue (mol_type ):
2514+ missing_token_indices .extend ([mol_miss_atom_indices for _ in range (mol .GetNumAtoms ())])
2515+ else :
2516+ missing_token_indices .append (mol_miss_atom_indices )
25202517
2521- assert len (molecules ) == len (missing_atom_indices )
2522- assert len (missing_token_indices ) == num_tokens
2518+ assert len (molecules ) == len (missing_atom_indices )
2519+ assert len (missing_token_indices ) == num_tokens
25232520
25242521 # TODO: install additional token features once MSAs are available
25252522 # 0: f_profile
0 commit comments