@@ -241,11 +241,10 @@ def inner(*args, **kwargs):
241241
242242# get atompair bonds functions
243243
244- ATOM_BOND_INDEX = {symbol : (idx + 1 ) for idx , symbol in enumerate (ATOM_BONDS )}
245-
246244@typecheck
247245def get_atompair_ids (
248246 mol : Mol ,
247+ atom_bonds : List [str ],
249248 directed_bonds : bool
250249) -> Int ['m m' ] | None :
251250
@@ -258,8 +257,10 @@ def get_atompair_ids(
258257 bonds = mol .GetBonds ()
259258 num_bonds = len (bonds )
260259
261- num_atom_bond_types = len (ATOM_BOND_INDEX )
262- other_index = len (ATOM_BONDS ) + 1
260+ atom_bond_index = {symbol : (idx + 1 ) for idx , symbol in enumerate (atom_bonds )}
261+
262+ num_atom_bond_types = len (atom_bond_index )
263+ other_index = len (atom_bond_index ) + 1
263264
264265 for bond in bonds :
265266 atom_start_index = bond .GetBeginAtomIdx ()
@@ -273,7 +274,7 @@ def get_atompair_ids(
273274 )
274275
275276 bond_type = bond .GetBondType ()
276- bond_id = ATOM_BOND_INDEX .get (bond_type , other_index ) + 1
277+ bond_id = atom_bond_index .get (bond_type , other_index ) + 1
277278
278279 # default to symmetric bond type (undirected atom bonds)
279280
@@ -761,7 +762,8 @@ class MoleculeInput:
761762 directed_bonds : bool = False
762763 extract_atom_feats_fn : Callable [[Atom ], Float ["m dai" ]] = default_extract_atom_feats_fn # type: ignore
763764 extract_atompair_feats_fn : Callable [[Mol ], Float ["m m dapi" ]] = default_extract_atompair_feats_fn # type: ignore
764- custom_atoms : List [str ]| None = None
765+ custom_atoms : List [str ] | None = None
766+ custom_bonds : List [str ] | None = None
765767
766768@typecheck
767769def molecule_to_atom_input (mol_input : MoleculeInput ) -> AtomInput :
@@ -891,6 +893,8 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
891893 prev_mol = None
892894 prev_src_tgt_atom_indices = None
893895
896+ atom_bonds = default (i .custom_bonds , ATOM_BONDS )
897+
894898 for (
895899 mol ,
896900 mol_id ,
@@ -914,7 +918,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
914918 should_cache = is_chainable_biomolecule .item ()
915919 )
916920
917- mol_atompair_ids = maybe_cached_get_atompair_ids (mol , directed_bonds = i .directed_bonds )
921+ mol_atompair_ids = maybe_cached_get_atompair_ids (mol , atom_bonds , directed_bonds = i .directed_bonds )
918922
919923 # /einx.set_at
920924
@@ -1103,7 +1107,8 @@ class MoleculeLengthMoleculeInput:
11031107 directed_bonds : bool = False
11041108 extract_atom_feats_fn : Callable [[Atom ], Float ["m dai" ]] = default_extract_atom_feats_fn # type: ignore
11051109 extract_atompair_feats_fn : Callable [[Mol ], Float ["m m dapi" ]] = default_extract_atompair_feats_fn # type: ignore
1106- custom_atoms : List [str ]| None = None
1110+ custom_atoms : List [str ] | None = None
1111+ custom_bonds : List [str ] | None = None
11071112
11081113
11091114@typecheck
@@ -1354,6 +1359,8 @@ def molecule_lengthed_molecule_input_to_atom_input(
13541359 prev_mol = None
13551360 prev_src_tgt_atom_indices = None
13561361
1362+ atom_bonds = default (i .custom_bonds , ATOM_BONDS )
1363+
13571364 for (
13581365 mol ,
13591366 mol_id ,
@@ -1377,7 +1384,7 @@ def molecule_lengthed_molecule_input_to_atom_input(
13771384 should_cache = is_chainable_biomolecule .item ()
13781385 )
13791386
1380- mol_atompair_ids = maybe_cached_get_atompair_ids (mol , directed_bonds = i .directed_bonds )
1387+ mol_atompair_ids = maybe_cached_get_atompair_ids (mol , atom_bonds , directed_bonds = i .directed_bonds )
13811388
13821389 # mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
13831390
@@ -1553,6 +1560,7 @@ class Alphafold3Input:
15531560 extract_atom_feats_fn : Callable [[Atom ], Float ["m dai" ]] = default_extract_atom_feats_fn # type: ignore
15541561 extract_atompair_feats_fn : Callable [[Mol ], Float ["m m dapi" ]] = default_extract_atompair_feats_fn # type: ignore
15551562 custom_atoms : List [str ] | None = None
1563+ custom_bonds : List [str ] | None = None
15561564
15571565@typecheck
15581566def map_int_or_string_indices_to_mol (
@@ -1999,7 +2007,8 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
19992007 directed_bonds = i .directed_bonds ,
20002008 extract_atom_feats_fn = i .extract_atom_feats_fn ,
20012009 extract_atompair_feats_fn = i .extract_atompair_feats_fn ,
2002- custom_atoms = i .custom_atoms
2010+ custom_atoms = i .custom_atoms ,
2011+ custom_bonds = i .custom_bonds
20032012 )
20042013
20052014 return molecule_input
@@ -2166,7 +2175,8 @@ class PDBInput:
21662175 add_atom_ids : bool = False
21672176 add_atompair_ids : bool = False
21682177 directed_bonds : bool = False
2169- custom_atoms : List [str ]| None = None
2178+ custom_atoms : List [str ] | None = None
2179+ custom_bonds : List [str ] | None = None
21702180 training : bool = False
21712181 inference : bool = False
21722182 distillation : bool = False
@@ -3982,7 +3992,8 @@ def pdb_input_to_molecule_input(
39823992 directed_bonds = i .directed_bonds ,
39833993 extract_atom_feats_fn = i .extract_atom_feats_fn ,
39843994 extract_atompair_feats_fn = i .extract_atompair_feats_fn ,
3985- custom_atoms = i .custom_atoms
3995+ custom_atoms = i .custom_atoms ,
3996+ custom_bonds = i .custom_bonds
39863997 )
39873998
39883999 return molecule_input
0 commit comments