@@ -135,7 +135,6 @@ class AtomInput:
135135 template_mask : Bool [' t' ] | None = None
136136 msa_mask : Bool [' s' ] | None = None
137137 atom_pos : Float ['m 3' ] | None = None
138- output_atompos_indices : Int [' m' ] | None = None
139138 molecule_atom_indices : Int [' n' ] | None = None
140139 distogram_atom_indices : Int [' n' ] | None = None
141140 distance_labels : Int ['n n' ] | None = None
@@ -166,7 +165,6 @@ class BatchedAtomInput:
166165 template_mask : Bool ['b t' ] | None = None
167166 msa_mask : Bool ['b s' ] | None = None
168167 atom_pos : Float ['b m 3' ] | None = None
169- output_atompos_indices : Int ['b m' ] | None = None
170168 molecule_atom_indices : Int ['b n' ] | None = None
171169 distogram_atom_indices : Int ['b n' ] | None = None
172170 distance_labels : Int ['b n n' ] | None = None
@@ -215,7 +213,6 @@ class MoleculeInput:
215213 templates : Float ['t n n dt' ] | None = None
216214 msa : Float ['s n dm' ] | None = None
217215 atom_pos : List [Float ['_ 3' ]] | Float ['m 3' ] | None = None
218- output_atompos_indices : Int [' m' ] | None = None
219216 template_mask : Bool [' t' ] | None = None
220217 msa_mask : Bool [' s' ] | None = None
221218 distance_labels : Int ['n n' ] | None = None
@@ -355,9 +352,13 @@ def molecule_to_atom_input(
355352 # and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
356353
357354 if is_chainable_biomolecule and not is_first_mol_in_chain :
355+
356+
358357 atompair_ids [offset , offset - 1 ] = 1
359358 atompair_ids [offset - 1 , offset ] = 1
360359
360+ last_mol = mol
361+
361362 # atom_inputs
362363
363364 atom_inputs : List [Float ['m dai' ]] = []
@@ -444,7 +445,6 @@ class Alphafold3Input:
444445 resolved_labels : Int [' n' ] | None = None
445446 add_atom_ids : bool = False
446447 add_atompair_ids : bool = False
447- add_output_atompos_indices : bool = True
448448 directed_bonds : bool = False
449449 extract_atom_feats_fn : Callable [[Atom ], Float ['m dai' ]] = default_extract_atom_feats_fn
450450 extract_atompair_feats_fn : Callable [[Mol ], Float ['m m dapi' ]] = default_extract_atompair_feats_fn
@@ -814,58 +814,9 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
814814 molecule_atom_indices = tensor (molecule_atom_indices )
815815 molecule_atom_indices = pad_to_len (molecule_atom_indices , num_tokens , value = - 1 )
816816
817- # handle atom positions
817+ # todo - handle atom positions for variable lengthed atoms (eventual missing atoms from mmCIF)
818818
819819 atom_pos = i .atom_pos
820- output_atompos_indices = None
821-
822- if exists (atom_pos ):
823- if isinstance (atom_pos , list ):
824- atom_pos = torch .cat (atom_pos , dim = - 2 )
825-
826- assert atom_pos .shape [- 2 ] == total_atoms
827-
828- # to automatically reorder the atom positions back to canonical
829-
830- if i .add_output_atompos_indices :
831- offset = 0
832-
833- reorder_atompos_indices = []
834- output_atompos_indices = []
835-
836- for chain in chainable_biomol_entries :
837- for idx , entry in enumerate (chain ):
838- is_last = idx == (len (chain ) - 1 )
839-
840- mol = entry ['rdchem_mol' ]
841- num_atoms = mol .GetNumAtoms ()
842- atom_reorder_indices = entry ['atom_reorder_indices' ]
843-
844- if not is_last :
845- num_atoms -= 1
846- atom_reorder_indices = atom_reorder_indices [:- 1 ]
847-
848- reorder_back_indices = atom_reorder_indices .argsort ()
849-
850- output_atompos_indices .append (reorder_back_indices + offset )
851- reorder_atompos_indices .append (atom_reorder_indices + offset )
852-
853- offset += num_atoms
854-
855- output_atompos_indices = torch .cat (output_atompos_indices , dim = - 1 )
856- output_atompos_indices = pad_to_length (output_atompos_indices , total_atoms , value = - 1 )
857-
858- # if atom positions are passed in, need to be reordered for the bonds between residues / nucleotides to be contiguous
859- # todo - fix to have no reordering needed (bonds are built not contiguous, just hydroxyl removed)
860-
861- if i .reorder_atom_pos :
862- orig_order = torch .arange (total_atoms )
863- reorder_atompos_indices = torch .cat (reorder_atompos_indices , dim = - 1 )
864- reorder_atompos_indices = pad_to_length (reorder_atompos_indices , total_atoms , value = - 1 )
865-
866- reorder_indices = torch .where (reorder_atompos_indices != - 1 , reorder_atompos_indices , orig_order )
867-
868- atom_pos = atom_pos [reorder_indices ]
869820
870821 # create molecule input
871822
@@ -880,7 +831,6 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
880831 additional_token_feats = default (i .additional_token_feats , torch .zeros (num_tokens , 2 )),
881832 is_molecule_types = is_molecule_types ,
882833 atom_pos = atom_pos ,
883- output_atompos_indices = output_atompos_indices ,
884834 templates = i .templates ,
885835 msa = i .msa ,
886836 template_mask = i .template_mask ,
0 commit comments