@@ -426,6 +426,8 @@ def alphafold3_input_to_molecule_input(
426426
427427 i = alphafold3_input
428428
429+ chainable_biomol_entries : List [List [dict ]] = [] # for reordering the atom positions at the end
430+
429431 ss_rnas = list (i .ss_rna )
430432 ss_dnas = list (i .ss_dna )
431433
@@ -467,25 +469,31 @@ def alphafold3_input_to_molecule_input(
467469 protein_ids = maybe_string_to_int (HUMAN_AMINO_ACIDS , protein ) + protein_offset
468470 molecule_ids .append (protein_ids )
469471
472+ chainable_biomol_entries .append (protein_entries )
473+
470474 # convert all single stranded nucleic acids to mol
471475
472476 mol_ss_dnas = []
473477 mol_ss_rnas = []
474478
475479 for seq in ss_rnas :
476- mol_seq = map_int_or_string_indices_to_mol (RNA_NUCLEOTIDES , seq , chain = True )
480+ mol_seq , ss_rna_entries = map_int_or_string_indices_to_mol (RNA_NUCLEOTIDES , seq , chain = True , return_entries = True )
477481 mol_ss_rnas .append (mol_seq )
478482
479483 rna_ids = maybe_string_to_int (RNA_NUCLEOTIDES , seq ) + rna_offset
480484 molecule_ids .append (rna_ids )
481485
486+ chainable_biomol_entries .append (ss_rna_entries )
487+
482488 for seq in ss_dnas :
483- mol_seq = map_int_or_string_indices_to_mol (DNA_NUCLEOTIDES , seq , chain = True )
489+ mol_seq , ss_dna_entries = map_int_or_string_indices_to_mol (DNA_NUCLEOTIDES , seq , chain = True , return_entries = True )
484490 mol_ss_dnas .append (mol_seq )
485491
486492 dna_ids = maybe_string_to_int (DNA_NUCLEOTIDES , seq ) + dna_offset
487493 molecule_ids .append (dna_ids )
488494
495+ chainable_biomol_entries .append (ss_dna_entries )
496+
489497 # convert metal ions to rdchem.Mol
490498
491499 metal_ions = alphafold3_input .metal_ions
@@ -558,6 +566,8 @@ def alphafold3_input_to_molecule_input(
558566 * metal_ions_pool_lens
559567 ]
560568
569+ total_atoms = sum (token_pool_lens )
570+
561571 # construct the token bonds
562572
563573 # will be linearly connected for proteins and nucleic acids
@@ -724,11 +734,38 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
724734 # handle atom positions
725735
726736 atom_pos = i .atom_pos
737+ output_atompos_indices = None
727738
728739 if exists (atom_pos ):
729740 if isinstance (atom_pos , list ):
730741 atom_pos = torch .cat (atom_pos , dim = - 2 )
731742
743+ # optionally build indices for reordering the atom positions back to canonical order
744+
745+ if i .add_output_atompos_indices :
746+ offset = 0
747+ output_atompos_indices = []
748+
749+ for chain in chainable_biomol_entries :
750+ for idx , entry in enumerate (chain ):
751+ is_last = idx == (len (chain ) - 1 )
752+
753+ mol = entry ['rdchem_mol' ]
754+ num_atoms = mol .GetNumAtoms ()
755+ atom_reorder_indices = entry ['atom_reorder_indices' ]
756+
757+ if not is_last :
758+ num_atoms -= 1
759+ atom_reorder_indices = atom_reorder_indices [:- 1 ]
760+
761+ reorder_back_indices = atom_reorder_indices .argsort ()
762+ output_atompos_indices .append (reorder_back_indices + offset )
763+
764+ offset += num_atoms
765+
766+ output_atompos_indices = torch .cat (output_atompos_indices , dim = - 1 )
767+ output_atompos_indices = F .pad (output_atompos_indices , (0 , total_atoms - output_atompos_indices .shape [- 1 ]), value = - 1 )
768+
732769 # create molecule input
733770
734771 molecule_input = MoleculeInput (
@@ -741,6 +778,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
741778 additional_token_feats = default (i .additional_token_feats , torch .zeros (num_tokens , 2 )),
742779 is_molecule_types = is_molecule_types ,
743780 atom_pos = atom_pos ,
781+ output_atompos_indices = output_atompos_indices ,
744782 templates = i .templates ,
745783 msa = i .msa ,
746784 template_mask = i .template_mask ,
0 commit comments