Skip to content

Commit fa1a065

Browse files
authored
first go for a simple cache for atompair ids for residues and nucleotides (#286)
* first go for a simple cache for atompair ids for residues and nucleotides * type * fix function * cache in another place * some more fixes
1 parent a14fc0b commit fa1a065

File tree

1 file changed

+133
-97
lines changed

1 file changed

+133
-97
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 133 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
import statistics
99
import traceback
1010
from collections import defaultdict
11-
from collections.abc import Iterable
1211
from contextlib import redirect_stderr
1312
from dataclasses import asdict, dataclass, field
1413
from datetime import datetime, timedelta
15-
from functools import partial
14+
from functools import partial, wraps
1615
from io import StringIO
1716
from itertools import groupby
1817
from pathlib import Path
19-
from retrying import retry
18+
19+
from collections.abc import Iterable
2020
from beartype.typing import (
2121
Any,
2222
Callable,
@@ -28,23 +28,29 @@
2828
Type,
2929
)
3030

31-
import einx
3231
import numpy as np
3332
import polars as pl
33+
3434
import timeout_decorator
35+
from retrying import retry
36+
3537
import torch
3638
import torch.nn.functional as F
39+
from torch import repeat_interleave, tensor
40+
from torch.nn.utils.rnn import pad_sequence
41+
from torch.utils.data import Dataset
42+
43+
import einx
3744
from einops import pack, rearrange
45+
3846
from joblib import Parallel, delayed
3947
from loguru import logger
4048
from pdbeccdutils.core import ccd_reader
49+
4150
from rdkit import Chem, RDLogger, rdBase
4251
from rdkit.Chem import AllChem, rdDetermineBonds
4352
from rdkit.Chem.rdchem import Atom, Mol
4453
from rdkit.Geometry import Point3D
45-
from torch import repeat_interleave, tensor
46-
from torch.nn.utils.rnn import pad_sequence
47-
from torch.utils.data import Dataset
4854

4955
from alphafold3_pytorch.common import (
5056
amino_acid_constants,
@@ -204,6 +210,100 @@
204210
}
205211
json.dump(CCD_COMPONENTS_SMILES, f)
206212

213+
# simple caching
214+
215+
# process-wide memo of per-molecule atompair id results
# entries are stored by `maybe_cache` under a caller-chosen string key
ATOMPAIR_IDS_CACHE = {}
216+
217+
@typecheck
def maybe_cache(
    fn,
    *,
    cache: dict,
    key: str,
    should_cache: bool = True
) -> Callable:
    """Optionally wrap `fn` so that its result is memoized in `cache` under `key`.

    Args:
        fn: The callable whose result may be cached.
        cache: Dictionary used as the backing store; it is mutated in place.
        key: The single key under which the result of `fn` is stored. It is
            fixed when the wrapper is built - the arguments later passed to the
            wrapper are forwarded to `fn` but play no part in the lookup.
        should_cache: When False, `fn` is returned untouched and `cache` is
            never read or written.

    Returns:
        `fn` itself if caching is disabled, otherwise a wrapper that returns
        `cache[key]` when present and otherwise calls `fn`, stores the result
        and returns it. The stored object is returned as-is (no copy).
    """
    if should_cache:

        @wraps(fn)
        def memoized(*args, **kwargs):
            # membership test (rather than a None check) so a cached `None` is still a hit
            if key not in cache:
                cache[key] = fn(*args, **kwargs)

            return cache[key]

        return memoized

    return fn
240+
241+
# get atompair bonds functions
242+
243+
# maps each entry of ATOM_BONDS to its 1-based position in that sequence
ATOM_BOND_INDEX = {symbol: idx for idx, symbol in enumerate(ATOM_BONDS, start = 1)}
244+
245+
@typecheck
def get_atompair_ids(
    mol: Mol,
    directed_bonds: bool
) -> Tensor | None:
    """Build the square matrix of bond ids for every bonded atom pair of `mol`.

    Args:
        mol: Molecule whose bonds are read through `GetBonds()`.
        directed_bonds: If True, the (end, begin) entry of each bond is offset
            by `len(ATOM_BOND_INDEX)` so the two directions get distinct ids;
            if False, both entries share the same id.

    Returns:
        A long tensor of shape (num_atoms, num_atoms) that is 0 wherever two
        atoms are not bonded and holds the bond id at both (begin, end) and
        (end, begin) of each bond, or `None` when the molecule has no bonds.
    """
    bonds = mol.GetBonds()

    # nothing to encode - bail out before allocating the (atoms x atoms) matrix

    if len(bonds) == 0:
        return None

    num_atom_bond_types = len(ATOM_BOND_INDEX)
    other_index = len(ATOM_BONDS) + 1

    coordinates = []
    updates = []

    for bond in bonds:
        atom_start_index = bond.GetBeginAtomIdx()
        atom_end_index = bond.GetEndAtomIdx()

        # one entry per direction so the resulting matrix covers both (i, j) and (j, i)

        coordinates.append((atom_start_index, atom_end_index))
        coordinates.append((atom_end_index, atom_start_index))

        # bond types missing from the lookup fall back to `other_index`
        # NOTE(review): the lookup keys come from ATOM_BONDS - confirm they compare equal
        # to what `GetBondType()` returns, otherwise every bond lands on `other_index`

        bond_id = ATOM_BOND_INDEX.get(bond.GetBondType(), other_index) + 1

        # default to symmetric bond type (undirected atom bonds)

        bond_to = bond_from = bond_id

        # if allowing for directed bonds, assume num_atompair_embeds = (2 * num_atom_bond_types) + 1
        # offset other edge by num_atom_bond_types

        if directed_bonds:
            bond_from += num_atom_bond_types

        updates.extend((bond_to, bond_from))

    num_atoms = mol.GetNumAtoms()
    mol_atompair_ids = torch.zeros(num_atoms, num_atoms, dtype = torch.long)

    # scatter the bond ids with plain advanced indexing over (row, col) pairs

    coordinates = tensor(coordinates).long()
    mol_atompair_ids[coordinates[:, 0], coordinates[:, 1]] = tensor(updates).long()

    return mol_atompair_ids
306+
207307
# functions
208308

209309

@@ -730,10 +830,6 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
730830
atompair_ids = None
731831

732832
if i.add_atompair_ids:
733-
atom_bond_index = {symbol: (idx + 1) for idx, symbol in enumerate(ATOM_BONDS)}
734-
num_atom_bond_types = len(atom_bond_index)
735-
736-
other_index = len(ATOM_BONDS) + 1
737833

738834
atompair_ids = torch.zeros(total_atoms, total_atoms).long()
739835

@@ -753,70 +849,35 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
753849

754850
for (
755851
mol,
852+
mol_id,
756853
is_first_mol_in_chain,
757854
is_chainable_biomolecule,
758855
src_tgt_atom_indices,
759856
offset,
760857
) in zip(
761858
molecules,
859+
molecule_ids,
762860
is_first_mol_in_chains,
763861
is_chainable_biomolecules,
764862
i.src_tgt_atom_indices,
765863
offsets,
766864
):
767-
coordinates = []
768-
updates = []
769-
770-
num_atoms = mol.GetNumAtoms()
771-
mol_atompair_ids = torch.zeros(num_atoms, num_atoms).long()
772-
773-
bonds = mol.GetBonds()
774-
num_bonds = len(bonds)
775865

776-
for bond in bonds:
777-
atom_start_index = bond.GetBeginAtomIdx()
778-
atom_end_index = bond.GetEndAtomIdx()
779-
780-
coordinates.extend(
781-
[
782-
[atom_start_index, atom_end_index],
783-
[atom_end_index, atom_start_index],
784-
]
785-
)
786-
787-
bond_type = bond.GetBondType()
788-
bond_id = atom_bond_index.get(bond_type, other_index) + 1
789-
790-
# default to symmetric bond type (undirected atom bonds)
791-
792-
bond_to = bond_from = bond_id
793-
794-
# if allowing for directed bonds, assume num_atompair_embeds = (2 * num_atom_bond_types) + 1
795-
# offset other edge by num_atom_bond_types
796-
797-
if i.directed_bonds:
798-
bond_from += num_atom_bond_types
799-
800-
updates.extend([bond_to, bond_from])
801-
802-
if num_bonds > 0:
803-
coordinates = tensor(coordinates).long()
804-
updates = tensor(updates).long()
805-
806-
# mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
807-
808-
molpair_strides = tensor(mol_atompair_ids.stride())
809-
flattened_coordinates = (coordinates * molpair_strides).sum(dim = -1)
810-
811-
packed_atompair_ids, unpack_one = pack_one(mol_atompair_ids, '*')
812-
packed_atompair_ids[flattened_coordinates] = updates
866+
maybe_cached_get_atompair_ids = maybe_cache(
867+
get_atompair_ids,
868+
cache = ATOMPAIR_IDS_CACHE,
869+
key = f'{mol_id}:{i.directed_bonds}',
870+
should_cache = is_chainable_biomolecule.item()
871+
)
813872

814-
mol_atompair_ids = unpack_one(packed_atompair_ids)
873+
mol_atompair_ids = maybe_cached_get_atompair_ids(mol, directed_bonds = i.directed_bonds)
815874

816875
# /einx.set_at
817876

818-
row_col_slice = slice(offset, offset + num_atoms)
819-
atompair_ids[row_col_slice, row_col_slice] = mol_atompair_ids
877+
if exists(mol_atompair_ids) and mol_atompair_ids.numel() > 0:
878+
num_atoms = mol.GetNumAtoms()
879+
row_col_slice = slice(offset, offset + num_atoms)
880+
atompair_ids[row_col_slice, row_col_slice] = mol_atompair_ids
820881

821882
# if is chainable biomolecule
822883
# and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
@@ -1251,10 +1312,6 @@ def molecule_lengthed_molecule_input_to_atom_input(
12511312
atompair_ids = None
12521313

12531314
if i.add_atompair_ids:
1254-
atom_bond_index = {symbol: (idx + 1) for idx, symbol in enumerate(ATOM_BONDS)}
1255-
num_atom_bond_types = len(atom_bond_index)
1256-
1257-
other_index = len(ATOM_BONDS) + 1
12581315

12591316
atompair_ids = torch.zeros(total_atoms, total_atoms).long()
12601317

@@ -1265,56 +1322,35 @@ def molecule_lengthed_molecule_input_to_atom_input(
12651322

12661323
for (
12671324
mol,
1325+
mol_id,
12681326
is_first_mol_in_chain,
12691327
is_chainable_biomolecule,
12701328
src_tgt_atom_indices,
12711329
offset,
12721330
) in zip(
12731331
molecules,
1332+
molecule_ids,
12741333
is_first_mol_in_chains,
12751334
is_chainable_biomolecules,
12761335
i.src_tgt_atom_indices,
12771336
offsets,
12781337
):
1279-
coordinates = []
1280-
updates = []
1281-
1282-
num_atoms = mol.GetNumAtoms()
1283-
mol_atompair_ids = torch.zeros(num_atoms, num_atoms).long()
1284-
1285-
for bond in mol.GetBonds():
1286-
atom_start_index = bond.GetBeginAtomIdx()
1287-
atom_end_index = bond.GetEndAtomIdx()
1288-
1289-
coordinates.extend(
1290-
[
1291-
[atom_start_index, atom_end_index],
1292-
[atom_end_index, atom_start_index],
1293-
]
1294-
)
1295-
1296-
bond_type = bond.GetBondType()
1297-
bond_id = atom_bond_index.get(bond_type, other_index) + 1
12981338

1299-
# default to symmetric bond type (undirected atom bonds)
1300-
1301-
bond_to = bond_from = bond_id
1302-
1303-
# if allowing for directed bonds, assume num_atompair_embeds = (2 * num_atom_bond_types) + 1
1304-
# offset other edge by num_atom_bond_types
1305-
1306-
if i.directed_bonds:
1307-
bond_from += num_atom_bond_types
1308-
1309-
updates.extend([bond_to, bond_from])
1339+
maybe_cached_get_atompair_ids = maybe_cache(
1340+
get_atompair_ids,
1341+
cache = ATOMPAIR_IDS_CACHE,
1342+
key = f'{mol_id}:{i.directed_bonds}',
1343+
should_cache = is_chainable_biomolecule.item()
1344+
)
13101345

1311-
coordinates = tensor(coordinates).long()
1312-
updates = tensor(updates).long()
1346+
mol_atompair_ids = maybe_cached_get_atompair_ids(mol, directed_bonds = i.directed_bonds)
13131347

13141348
# mol_atompair_ids = einx.set_at("[h w], c [2], c -> [h w]", mol_atompair_ids, coordinates, updates)
13151349

1316-
row_col_slice = slice(offset, offset + num_atoms)
1317-
atompair_ids[row_col_slice, row_col_slice] = mol_atompair_ids
1350+
if exists(mol_atompair_ids) and mol_atompair_ids.numel() > 0:
1351+
num_atoms = mol.GetNumAtoms()
1352+
row_col_slice = slice(offset, offset + num_atoms)
1353+
atompair_ids[row_col_slice, row_col_slice] = mol_atompair_ids
13181354

13191355
# if is chainable biomolecule
13201356
# and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule

0 commit comments

Comments
 (0)