@@ -372,6 +372,7 @@ class Alphafold3Input:
372372 templates : Float ['t n n dt' ] | None = None
373373 msa : Float ['s n dm' ] | None = None
374374 atom_pos : List [Float ['_ 3' ]] | Float ['m 3' ] | None = None
375+ reorder_atom_pos : bool = True
375376 template_mask : Bool [' t' ] | None = None
376377 msa_mask : Bool [' s' ] | None = None
377378 distance_labels : Int ['n n' ] | None = None
@@ -770,6 +771,8 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
770771
771772 if i .add_output_atompos_indices :
772773 offset = 0
774+
775+ reorder_atompos_indices = []
773776 output_atompos_indices = []
774777
775778 for chain in chainable_biomol_entries :
@@ -784,6 +787,8 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
784787 num_atoms -= 1
785788 atom_reorder_indices = atom_reorder_indices [:- 1 ]
786789
790+ reorder_atompos_indices .append (atom_reorder_indices )
791+
787792 reorder_back_indices = atom_reorder_indices .argsort ()
788793 output_atompos_indices .append (reorder_back_indices + offset )
789794
@@ -792,6 +797,18 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
792797 output_atompos_indices = torch .cat (output_atompos_indices , dim = - 1 )
793798 output_atompos_indices = F .pad (output_atompos_indices , (0 , total_atoms - output_atompos_indices .shape [- 1 ]), value = - 1 )
794799
800+ # if atom positions are passed in, need to be reordered for the bonds between residues / nucleotides to be contiguous
801+ # todo - fix to have no reordering needed (bonds are built not contiguous, just hydroxyl removed)
802+
803+ if i .reorder_atom_pos :
804+ orig_order = torch .arange (total_atoms )
805+ reorder_atompos_indices = torch .cat (reorder_atompos_indices , dim = - 1 )
806+ reorder_atompos_indices = F .pad (reorder_atompos_indices , (0 , total_atoms - reorder_atompos_indices .shape [- 1 ]), value = - 1 )
807+
808+ reorder_indices = torch .where (reorder_atompos_indices != - 1 , reorder_atompos_indices , orig_order )
809+
810+ atom_pos = atom_pos [reorder_indices ]
811+
795812 # create molecule input
796813
797814 molecule_input = MoleculeInput (
0 commit comments