@@ -137,6 +137,7 @@ class AtomInput:
137137 atom_pos : Float ['m 3' ] | None = None
138138 output_atompos_indices : Int [' m' ] | None = None
139139 molecule_atom_indices : Int [' n' ] | None = None
140+ distogram_atom_indices : Int [' n' ] | None = None
140141 distance_labels : Int ['n n' ] | None = None
141142 pae_labels : Int ['n n' ] | None = None
142143 pde_labels : Int ['n n' ] | None = None
@@ -167,6 +168,7 @@ class BatchedAtomInput:
167168 atom_pos : Float ['b m 3' ] | None = None
168169 output_atompos_indices : Int ['b m' ] | None = None
169170 molecule_atom_indices : Int ['b n' ] | None = None
171+ distogram_atom_indices : Int ['b n' ] | None = None
170172 distance_labels : Int ['b n n' ] | None = None
171173 pae_labels : Int ['b n n' ] | None = None
172174 pde_labels : Int ['b n n' ] | None = None
@@ -202,11 +204,12 @@ def default_extract_atompair_feats_fn(mol: Mol):
202204class MoleculeInput :
203205 molecules : List [Mol ]
204206 molecule_token_pool_lens : List [int ]
205- molecule_atom_indices : List [int | None ]
206207 molecule_ids : Int [' n' ]
207208 additional_molecule_feats : Int [f'n { ADDITIONAL_MOLECULE_FEATS } ' ]
208209 is_molecule_types : Bool [f'n { IS_MOLECULE_TYPES } ' ]
209210 token_bonds : Bool ['n n' ]
211+ molecule_atom_indices : List [int | None ] | None = None
212+ distogram_atom_indices : List [int | None ] | None = None
210213 atom_parent_ids : Int [' m' ] | None = None
211214 additional_token_feats : Float [f'n dtf' ] | None = None
212215 templates : Float ['t n n dt' ] | None = None
@@ -398,6 +401,8 @@ def molecule_to_atom_input(
398401 atompair_inputs = atompair_inputs ,
399402 molecule_atom_lens = tensor (atom_lens , dtype = torch .long ),
400403 molecule_ids = i .molecule_ids ,
404+ molecule_atom_indices = i .molecule_atom_indices ,
405+ distogram_atom_indices = i .distogram_atom_indices ,
401406 additional_token_feats = i .additional_token_feats ,
402407 additional_molecule_feats = i .additional_molecule_feats ,
403408 is_molecule_types = i .is_molecule_types ,
@@ -542,12 +547,15 @@ def alphafold3_input_to_molecule_input(
542547 proteins = i .proteins
543548 mol_proteins = []
544549 protein_entries = []
550+
551+ distogram_atom_indices = []
545552 molecule_atom_indices = []
546553
547554 for protein in proteins :
548555 mol_peptides , protein_entries = map_int_or_string_indices_to_mol (HUMAN_AMINO_ACIDS , protein , chain = True , return_entries = True )
549556 mol_proteins .append (mol_peptides )
550557
558+ distogram_atom_indices .extend ([entry ['token_center_atom_idx' ] for entry in protein_entries ])
551559 molecule_atom_indices .extend ([entry ['distogram_atom_idx' ] for entry in protein_entries ])
552560
553561 protein_ids = maybe_string_to_int (HUMAN_AMINO_ACIDS , protein ) + protein_offset
@@ -810,7 +818,10 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
810818 sym_ids
811819 ), dim = - 1 )
812820
813- # molecule atom indices
821+ # distogram and token centre atom indices
822+
823+ distogram_atom_indices = tensor (distogram_atom_indices )
824+ distogram_atom_indices = pad_to_len (distogram_atom_indices , num_tokens , value = - 1 )
814825
815826 molecule_atom_indices = tensor (molecule_atom_indices )
816827 molecule_atom_indices = pad_to_len (molecule_atom_indices , num_tokens , value = - 1 )
@@ -874,6 +885,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
874885 molecules = molecules ,
875886 molecule_token_pool_lens = token_pool_lens ,
876887 molecule_atom_indices = molecule_atom_indices ,
888+ distogram_atom_indices = distogram_atom_indices ,
877889 molecule_ids = molecule_ids ,
878890 token_bonds = token_bonds ,
879891 additional_molecule_feats = additional_molecule_feats ,
0 commit comments