@@ -180,6 +180,7 @@ class MoleculeInput:
180180 resolved_labels : Int [' n' ] | None = None
181181 add_atom_ids : bool = False
182182 add_atompair_ids : bool = False
183+ directed_bonds : bool = False
183184 extract_atom_feats_fn : Callable [[Atom ], Float ['m dai' ]] = default_extract_atom_feats_fn
184185 extract_atompair_feats_fn : Callable [[Mol ], Float ['m m dapi' ]] = default_extract_atompair_feats_fn
185186
@@ -247,6 +248,8 @@ def molecule_to_atom_input(
247248
248249 if i .add_atompair_ids :
249250 atom_bond_index = {symbol : (idx + 1 ) for idx , symbol in enumerate (ATOM_BONDS )}
251+ num_atom_bond_types = len (atom_bond_index )
252+
250253 other_index = len (ATOM_BONDS ) + 1
251254
252255 atompair_ids = torch .zeros (total_atoms , total_atoms ).long ()
@@ -284,7 +287,17 @@ def molecule_to_atom_input(
284287 bond_type = bond .GetBondType ()
285288 bond_id = atom_bond_index .get (bond_type , other_index ) + 1
286289
287- updates .extend ([bond_id , bond_id ])
290+ # default to symmetric bond type (undirected atom bonds)
291+
292+ bond_to = bond_from = bond_id
293+
294+ # if allowing for directed bonds, assume num_atompair_embeds = (2 * num_atom_bond_types) + 1
295+ # offset other edge by num_atom_bond_types
296+
297+ if i .directed_bonds :
298+ bond_from += num_atom_bond_types
299+
300+ updates .extend ([bond_to , bond_from ])
288301
289302 coordinates = tensor (coordinates ).long ()
290303 updates = tensor (updates ).long ()
@@ -386,6 +399,7 @@ class Alphafold3Input:
386399 add_atom_ids : bool = False
387400 add_atompair_ids : bool = False
388401 add_output_atompos_indices : bool = True
402+ directed_bonds : bool = False
389403 extract_atom_feats_fn : Callable [[Atom ], Float ['m dai' ]] = default_extract_atom_feats_fn
390404 extract_atompair_feats_fn : Callable [[Mol ], Float ['m m dapi' ]] = default_extract_atompair_feats_fn
391405
@@ -833,6 +847,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
833847 atom_parent_ids = atom_parent_ids ,
834848 add_atom_ids = i .add_atom_ids ,
835849 add_atompair_ids = i .add_atompair_ids ,
850+ directed_bonds = i .directed_bonds ,
836851 extract_atom_feats_fn = i .extract_atom_feats_fn ,
837852 extract_atompair_feats_fn = i .extract_atompair_feats_fn
838853 )
0 commit comments