@@ -544,8 +544,10 @@ def map_int_or_string_indices_to_mol(
544544def maybe_string_to_int (
545545 entries : dict ,
546546 indices : Int [' _' ] | List [str ] | str ,
547- other_index : int = 0
548547) -> Int [' _' ]:
548+
549+ unknown_index = len (entries ) - 1
550+
549551 if isinstance (indices , str ):
550552 indices = list (indices )
551553
@@ -554,7 +556,7 @@ def maybe_string_to_int(
554556
555557 index = {symbol : i for i , symbol in enumerate (entries .keys ())}
556558
557- return tensor ([index [ c ] for c in indices ]).long ()
559+ return tensor ([index . get ( c , unknown_index ) for c in indices ]).long ()
558560
559561@typecheck
560562def alphafold3_input_to_molecule_input (
@@ -582,12 +584,14 @@ def alphafold3_input_to_molecule_input(
582584 ss_dnas .extend ([seq , rc_seq ])
583585
584586 # keep track of molecule_ids - for now it is
585- # other(1 ) | proteins (20 ) | rna (4 ) | dna (4)
587+ # proteins (21) | rna (5) | dna (5) | gap? (1) - the unknown id for each biomolecule type is the last entry of its range; ligands use id 20 (the last protein id) and metal ions use the gap id
586588
587- protein_offset = 1
588- rna_offset = len (HUMAN_AMINO_ACIDS ) + protein_offset
589+ rna_offset = len (HUMAN_AMINO_ACIDS )
589590 dna_offset = len (RNA_NUCLEOTIDES ) + rna_offset
590591
592+ ligand_id = len (HUMAN_AMINO_ACIDS ) - 1
593+ gap_id = len (DNA_NUCLEOTIDES ) + dna_offset
594+
591595 molecule_ids = []
592596
593597 # convert all proteins to a List[Mol] of each peptide
@@ -609,7 +613,7 @@ def alphafold3_input_to_molecule_input(
609613
610614 src_tgt_atom_indices .extend ([[entry ['first_atom_idx' ], entry ['last_atom_idx' ]] for entry in protein_entries ])
611615
612- protein_ids = maybe_string_to_int (HUMAN_AMINO_ACIDS , protein ) + protein_offset
616+ protein_ids = maybe_string_to_int (HUMAN_AMINO_ACIDS , protein )
613617 molecule_ids .append (protein_ids )
614618
615619 chainable_biomol_entries .append (protein_entries )
@@ -652,11 +656,15 @@ def alphafold3_input_to_molecule_input(
652656 metal_ions = alphafold3_input .metal_ions
653657 mol_metal_ions = map_int_or_string_indices_to_mol (METALS , metal_ions )
654658
659+ molecule_ids .append (tensor ([gap_id ] * len (mol_metal_ions )))
660+
655661 # convert ligands to rdchem.Mol
656662
657663 ligands = list (alphafold3_input .ligands )
658664 mol_ligands = [(mol_from_smile (ligand ) if isinstance (ligand , str ) else ligand ) for ligand in ligands ]
659665
666+ molecule_ids .append (tensor ([ligand_id ] * len (mol_ligands )))
667+
660668 # create the molecule input
661669
662670 all_protein_mols = flatten (mol_proteins )
@@ -766,7 +774,7 @@ def alphafold3_input_to_molecule_input(
766774
767775 # handle molecule ids
768776
769- molecule_ids = torch .cat (molecule_ids )
777+ molecule_ids = torch .cat (molecule_ids ). long ()
770778 molecule_ids = pad_to_len (molecule_ids , num_tokens )
771779
772780 # handle atom_parent_ids
0 commit comments