88import statistics
99import traceback
1010from collections import defaultdict
11- from collections .abc import Iterable
1211from contextlib import redirect_stderr
1312from dataclasses import asdict , dataclass , field
1413from datetime import datetime , timedelta
15- from functools import partial
14+ from functools import partial , wraps
1615from io import StringIO
1716from itertools import groupby
1817from pathlib import Path
19- from retrying import retry
18+
19+ from collections .abc import Iterable
2020from beartype .typing import (
2121 Any ,
2222 Callable ,
2828 Type ,
2929)
3030
31- import einx
3231import numpy as np
3332import polars as pl
33+
3434import timeout_decorator
35+ from retrying import retry
36+
3537import torch
3638import torch .nn .functional as F
39+ from torch import repeat_interleave , tensor
40+ from torch .nn .utils .rnn import pad_sequence
41+ from torch .utils .data import Dataset
42+
43+ import einx
3744from einops import pack , rearrange
45+
3846from joblib import Parallel , delayed
3947from loguru import logger
4048from pdbeccdutils .core import ccd_reader
49+
4150from rdkit import Chem , RDLogger , rdBase
4251from rdkit .Chem import AllChem , rdDetermineBonds
4352from rdkit .Chem .rdchem import Atom , Mol
4453from rdkit .Geometry import Point3D
45- from torch import repeat_interleave , tensor
46- from torch .nn .utils .rnn import pad_sequence
47- from torch .utils .data import Dataset
4854
4955from alphafold3_pytorch .common import (
5056 amino_acid_constants ,
204210 }
205211 json .dump (CCD_COMPONENTS_SMILES , f )
206212
213+ # simple caching
214+
215+ ATOMPAIR_IDS_CACHE = dict ()
216+
217+ @typecheck
218+ def maybe_cache (
219+ fn ,
220+ * ,
221+ cache : dict ,
222+ key : str ,
223+ should_cache : bool = True
224+ ) -> Callable :
225+
226+ if not should_cache :
227+ return fn
228+
229+ @wraps (fn )
230+ def inner (* args , ** kwargs ):
231+ if key in cache :
232+ return cache [key ]
233+
234+ out = fn (* args , ** kwargs )
235+
236+ cache [key ] = out
237+ return out
238+
239+ return inner
240+
241+ # get atompair bonds functions
242+
243+ ATOM_BOND_INDEX = {symbol : (idx + 1 ) for idx , symbol in enumerate (ATOM_BONDS )}
244+
245+ @typecheck
246+ def get_atompair_ids (
247+ mol : Mol ,
248+ directed_bonds : bool
249+ ) -> Tensor | None :
250+
251+ coordinates = []
252+ updates = []
253+
254+ num_atoms = mol .GetNumAtoms ()
255+ mol_atompair_ids = torch .zeros (num_atoms , num_atoms ).long ()
256+
257+ bonds = mol .GetBonds ()
258+ num_bonds = len (bonds )
259+
260+ num_atom_bond_types = len (ATOM_BOND_INDEX )
261+ other_index = len (ATOM_BONDS ) + 1
262+
263+ for bond in bonds :
264+ atom_start_index = bond .GetBeginAtomIdx ()
265+ atom_end_index = bond .GetEndAtomIdx ()
266+
267+ coordinates .extend (
268+ [
269+ [atom_start_index , atom_end_index ],
270+ [atom_end_index , atom_start_index ],
271+ ]
272+ )
273+
274+ bond_type = bond .GetBondType ()
275+ bond_id = ATOM_BOND_INDEX .get (bond_type , other_index ) + 1
276+
277+ # default to symmetric bond type (undirected atom bonds)
278+
279+ bond_to = bond_from = bond_id
280+
281+ # if allowing for directed bonds, assume num_atompair_embeds = (2 * num_atom_bond_types) + 1
282+ # offset other edge by num_atom_bond_types
283+
284+ if directed_bonds :
285+ bond_from += num_atom_bond_types
286+
287+ updates .extend ([bond_to , bond_from ])
288+
289+ if num_bonds == 0 :
290+ return None
291+
292+ coordinates = tensor (coordinates ).long ()
293+ updates = tensor (updates ).long ()
294+
295+ # mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
296+
297+ molpair_strides = tensor (mol_atompair_ids .stride ())
298+ flattened_coordinates = (coordinates * molpair_strides ).sum (dim = - 1 )
299+
300+ packed_atompair_ids , unpack_one = pack_one (mol_atompair_ids , '*' )
301+ packed_atompair_ids [flattened_coordinates ] = updates
302+
303+ mol_atompair_ids = unpack_one (packed_atompair_ids )
304+
305+ return mol_atompair_ids
306+
207307# functions
208308
209309
@@ -730,10 +830,6 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
730830 atompair_ids = None
731831
732832 if i .add_atompair_ids :
733- atom_bond_index = {symbol : (idx + 1 ) for idx , symbol in enumerate (ATOM_BONDS )}
734- num_atom_bond_types = len (atom_bond_index )
735-
736- other_index = len (ATOM_BONDS ) + 1
737833
738834 atompair_ids = torch .zeros (total_atoms , total_atoms ).long ()
739835
@@ -753,70 +849,35 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
753849
754850 for (
755851 mol ,
852+ mol_id ,
756853 is_first_mol_in_chain ,
757854 is_chainable_biomolecule ,
758855 src_tgt_atom_indices ,
759856 offset ,
760857 ) in zip (
761858 molecules ,
859+ molecule_ids ,
762860 is_first_mol_in_chains ,
763861 is_chainable_biomolecules ,
764862 i .src_tgt_atom_indices ,
765863 offsets ,
766864 ):
767- coordinates = []
768- updates = []
769-
770- num_atoms = mol .GetNumAtoms ()
771- mol_atompair_ids = torch .zeros (num_atoms , num_atoms ).long ()
772-
773- bonds = mol .GetBonds ()
774- num_bonds = len (bonds )
775865
776- for bond in bonds :
777- atom_start_index = bond .GetBeginAtomIdx ()
778- atom_end_index = bond .GetEndAtomIdx ()
779-
780- coordinates .extend (
781- [
782- [atom_start_index , atom_end_index ],
783- [atom_end_index , atom_start_index ],
784- ]
785- )
786-
787- bond_type = bond .GetBondType ()
788- bond_id = atom_bond_index .get (bond_type , other_index ) + 1
789-
790- # default to symmetric bond type (undirected atom bonds)
791-
792- bond_to = bond_from = bond_id
793-
794- # if allowing for directed bonds, assume num_atompair_embeds = (2 * num_atom_bond_types) + 1
795- # offset other edge by num_atom_bond_types
796-
797- if i .directed_bonds :
798- bond_from += num_atom_bond_types
799-
800- updates .extend ([bond_to , bond_from ])
801-
802- if num_bonds > 0 :
803- coordinates = tensor (coordinates ).long ()
804- updates = tensor (updates ).long ()
805-
806- # mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
807-
808- molpair_strides = tensor (mol_atompair_ids .stride ())
809- flattened_coordinates = (coordinates * molpair_strides ).sum (dim = - 1 )
810-
811- packed_atompair_ids , unpack_one = pack_one (mol_atompair_ids , '*' )
812- packed_atompair_ids [flattened_coordinates ] = updates
866+ maybe_cached_get_atompair_ids = maybe_cache (
867+ get_atompair_ids ,
868+ cache = ATOMPAIR_IDS_CACHE ,
869+ key = f'{ mol_id } :{ i .directed_bonds } ' ,
870+ should_cache = is_chainable_biomolecule .item ()
871+ )
813872
814- mol_atompair_ids = unpack_one ( packed_atompair_ids )
873+ mol_atompair_ids = maybe_cached_get_atompair_ids ( mol , directed_bonds = i . directed_bonds )
815874
816875 # /einx.set_at
817876
818- row_col_slice = slice (offset , offset + num_atoms )
819- atompair_ids [row_col_slice , row_col_slice ] = mol_atompair_ids
877+ if exists (mol_atompair_ids ) and mol_atompair_ids .numel () > 0 :
878+ num_atoms = mol .GetNumAtoms ()
879+ row_col_slice = slice (offset , offset + num_atoms )
880+ atompair_ids [row_col_slice , row_col_slice ] = mol_atompair_ids
820881
821882 # if is chainable biomolecule
822883 # and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
@@ -1251,10 +1312,6 @@ def molecule_lengthed_molecule_input_to_atom_input(
12511312 atompair_ids = None
12521313
12531314 if i .add_atompair_ids :
1254- atom_bond_index = {symbol : (idx + 1 ) for idx , symbol in enumerate (ATOM_BONDS )}
1255- num_atom_bond_types = len (atom_bond_index )
1256-
1257- other_index = len (ATOM_BONDS ) + 1
12581315
12591316 atompair_ids = torch .zeros (total_atoms , total_atoms ).long ()
12601317
@@ -1265,56 +1322,35 @@ def molecule_lengthed_molecule_input_to_atom_input(
12651322
12661323 for (
12671324 mol ,
1325+ mol_id ,
12681326 is_first_mol_in_chain ,
12691327 is_chainable_biomolecule ,
12701328 src_tgt_atom_indices ,
12711329 offset ,
12721330 ) in zip (
12731331 molecules ,
1332+ molecule_ids ,
12741333 is_first_mol_in_chains ,
12751334 is_chainable_biomolecules ,
12761335 i .src_tgt_atom_indices ,
12771336 offsets ,
12781337 ):
1279- coordinates = []
1280- updates = []
1281-
1282- num_atoms = mol .GetNumAtoms ()
1283- mol_atompair_ids = torch .zeros (num_atoms , num_atoms ).long ()
1284-
1285- for bond in mol .GetBonds ():
1286- atom_start_index = bond .GetBeginAtomIdx ()
1287- atom_end_index = bond .GetEndAtomIdx ()
1288-
1289- coordinates .extend (
1290- [
1291- [atom_start_index , atom_end_index ],
1292- [atom_end_index , atom_start_index ],
1293- ]
1294- )
1295-
1296- bond_type = bond .GetBondType ()
1297- bond_id = atom_bond_index .get (bond_type , other_index ) + 1
12981338
1299- # default to symmetric bond type (undirected atom bonds)
1300-
1301- bond_to = bond_from = bond_id
1302-
1303- # if allowing for directed bonds, assume num_atompair_embeds = (2 * num_atom_bond_types) + 1
1304- # offset other edge by num_atom_bond_types
1305-
1306- if i .directed_bonds :
1307- bond_from += num_atom_bond_types
1308-
1309- updates .extend ([bond_to , bond_from ])
1339+ maybe_cached_get_atompair_ids = maybe_cache (
1340+ get_atompair_ids ,
1341+ cache = ATOMPAIR_IDS_CACHE ,
1342+ key = f'{ mol_id } :{ i .directed_bonds } ' ,
1343+ should_cache = is_chainable_biomolecule .item ()
1344+ )
13101345
1311- coordinates = tensor (coordinates ).long ()
1312- updates = tensor (updates ).long ()
1346+ mol_atompair_ids = maybe_cached_get_atompair_ids (mol , directed_bonds = i .directed_bonds )
13131347
13141348 # mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
13151349
1316- row_col_slice = slice (offset , offset + num_atoms )
1317- atompair_ids [row_col_slice , row_col_slice ] = mol_atompair_ids
1350+ if exists (mol_atompair_ids ) and mol_atompair_ids .numel () > 0 :
1351+ num_atoms = mol .GetNumAtoms ()
1352+ row_col_slice = slice (offset , offset + num_atoms )
1353+ atompair_ids [row_col_slice , row_col_slice ] = mol_atompair_ids
13181354
13191355 # if is chainable biomolecule
13201356 # and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
0 commit comments