@@ -154,15 +154,17 @@ def molecule_to_atom_input(
154154 mol_input : MoleculeInput
155155) -> AtomInput :
156156
157- molecules = mol_input .molecules
158- atom_lens = mol_input .molecule_token_pool_lens
157+ i = mol_input
158+
159+ molecules = i .molecules
160+ atom_lens = i .molecule_token_pool_lens
159161
160162 # get total number of atoms
161163
162164 if not exists (atom_lens ):
163165 atom_lens = []
164166
165- for mol , is_ligand in zip (molecules , mol_input .is_molecule_types [:, - 1 ]):
167+ for mol , is_ligand in zip (molecules , i .is_molecule_types [:, - 1 ]):
166168 num_atoms = mol .GetNumAtoms ()
167169
168170 if is_ligand :
@@ -184,7 +186,7 @@ def molecule_to_atom_input(
184186
185187 atom_ids = None
186188
187- if mol_input .add_atom_ids :
189+ if i .add_atom_ids :
188190 atom_index = {symbol : i for i , symbol in enumerate (ATOMS )}
189191
190192 atom_ids = []
@@ -201,14 +203,22 @@ def molecule_to_atom_input(
201203
202204 atompair_ids = None
203205
204- if mol_input .add_atompair_ids :
205- atom_bond_index = {symbol : (i + 1 ) for i , symbol in enumerate (ATOM_BONDS )}
206+ if i .add_atompair_ids :
207+ atom_bond_index = {symbol : (idx + 1 ) for idx , symbol in enumerate (ATOM_BONDS )}
206208 other_index = len (ATOM_BONDS ) + 1
207209
208210 atompair_ids = torch .zeros (total_atoms , total_atoms ).long ()
211+
209212 offset = 0
210213
211- for mol in molecules :
214+ # need the asym_id (each molecule for each chain, ascending) as well as `is_protein | is_dna | is_rna` from is_molecule_types (the chainable biomolecules)
215+ # a single bond is added from a peptide or nucleotide to the one before it, unless `asym_id` == 0 (first in the chain)
216+
217+ asym_ids = i .additional_molecule_feats [..., 2 ]
218+ is_chainable_biomolecules = i .is_molecule_types [..., :3 ].any (dim = - 1 )
219+
220+ for idx , (mol , asym_id , is_chainable_biomolecule ) in enumerate (zip (molecules , asym_ids , is_chainable_biomolecules )):
221+
212222 coordinates = []
213223 updates = []
214224
@@ -225,7 +235,7 @@ def molecule_to_atom_input(
225235 ])
226236
227237 bond_type = bond .GetBondType ()
228- bond_id = atom_bond_index .get (bond_type , other_index )
238+ bond_id = atom_bond_index .get (bond_type , other_index ) + 1
229239
230240 updates .extend ([bond_id , bond_id ])
231241
@@ -237,7 +247,14 @@ def molecule_to_atom_input(
237247 row_col_slice = slice (offset , offset + num_atoms )
238248 atompair_ids [row_col_slice , row_col_slice ] = mol_atompair_ids
239249
240- offset += num_atoms
250+ # if this is a chainable biomolecule
251+ # and not the first biomolecule in the chain, add a single covalent bond between the first atom of the incoming biomolecule and the last atom of the previous biomolecule
252+
253+ if is_chainable_biomolecule and asym_id != 0 :
254+ atompair_ids [offset , offset - 1 ] = 1
255+ atompair_ids [offset - 1 , offset ] = 1
256+
257+ offset += num_atoms
241258
242259 # atom_inputs
243260
@@ -266,8 +283,8 @@ def molecule_to_atom_input(
266283
267284 all_atom_pos = []
268285
269- for i , atom in enumerate (mol .GetAtoms ()):
270- pos = mol .GetConformer ().GetAtomPosition (i )
286+ for idx , atom in enumerate (mol .GetAtoms ()):
287+ pos = mol .GetConformer ().GetAtomPosition (idx )
271288 all_atom_pos .append ([pos .x , pos .y , pos .z ])
272289
273290 all_atom_pos_tensor = tensor (all_atom_pos )
@@ -285,12 +302,12 @@ def molecule_to_atom_input(
285302 atom_inputs = tensor (atom_inputs , dtype = torch .float ),
286303 atompair_inputs = atompair_inputs ,
287304 molecule_atom_lens = tensor (atom_lens , dtype = torch .long ),
288- molecule_ids = mol_input .molecule_ids ,
289- additional_token_feats = mol_input .additional_token_feats ,
290- additional_molecule_feats = mol_input .additional_molecule_feats ,
291- is_molecule_types = mol_input .is_molecule_types ,
292- token_bonds = mol_input .token_bonds ,
293- atom_parent_ids = mol_input .atom_parent_ids ,
305+ molecule_ids = i .molecule_ids ,
306+ additional_token_feats = i .additional_token_feats ,
307+ additional_molecule_feats = i .additional_molecule_feats ,
308+ is_molecule_types = i .is_molecule_types ,
309+ token_bonds = i .token_bonds ,
310+ atom_parent_ids = i .atom_parent_ids ,
294311 atom_ids = atom_ids ,
295312 atompair_ids = atompair_ids
296313 )
0 commit comments