Skip to content

Commit f585a6a

Browse files
committed
cache the calculation of has_bond in molecule_lengthed_molecule_input_to_atom_input
1 parent fa1a065 commit f585a6a

File tree

2 files changed

+62
-34
lines changed

2 files changed

+62
-34
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -213,17 +213,18 @@
213213
# simple caching
214214

215215
ATOMPAIR_IDS_CACHE = dict()
216+
HAS_BOND_CACHE = dict()
216217

217218
@typecheck
218219
def maybe_cache(
219220
fn,
220221
*,
221222
cache: dict,
222-
key: str,
223+
key: str | None,
223224
should_cache: bool = True
224225
) -> Callable:
225226

226-
if not should_cache:
227+
if not should_cache or not exists(key):
227228
return fn
228229

229230
@wraps(fn)
@@ -246,7 +247,7 @@ def inner(*args, **kwargs):
246247
def get_atompair_ids(
247248
mol: Mol,
248249
directed_bonds: bool
249-
) -> Tensor | None:
250+
) -> Int['m m'] | None:
250251

251252
coordinates = []
252253
updates = []
@@ -304,6 +305,46 @@ def get_atompair_ids(
304305

305306
return mol_atompair_ids
306307

308+
@typecheck
309+
def get_mol_has_bond(
310+
mol: Mol
311+
) -> Bool['m m'] | None:
312+
313+
coordinates = []
314+
315+
bonds = mol.GetBonds()
316+
num_bonds = len(bonds)
317+
318+
for bond in bonds:
319+
atom_start_index = bond.GetBeginAtomIdx()
320+
atom_end_index = bond.GetEndAtomIdx()
321+
322+
coordinates.extend(
323+
[
324+
[atom_start_index, atom_end_index],
325+
[atom_end_index, atom_start_index],
326+
]
327+
)
328+
329+
if num_bonds == 0:
330+
return None
331+
332+
num_atoms = mol.GetNumAtoms()
333+
has_bond = torch.zeros(num_atoms, num_atoms).bool()
334+
335+
coordinates = tensor(coordinates).long()
336+
337+
# has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
338+
339+
has_bond_stride = tensor(has_bond.stride())
340+
flattened_coordinates = (coordinates * has_bond_stride).sum(dim = -1)
341+
packed_has_bond, unpack_has_bond = pack_one(has_bond, '*')
342+
343+
packed_has_bond[flattened_coordinates] = True
344+
has_bond = unpack_has_bond(packed_has_bond, '*')
345+
346+
return has_bond
347+
307348
# functions
308349

309350

@@ -1182,49 +1223,36 @@ def molecule_lengthed_molecule_input_to_atom_input(
11821223

11831224
for (
11841225
mol,
1226+
mol_id,
11851227
mol_is_chainable_biomolecule,
11861228
mol_is_first_mol_in_chain,
11871229
mol_is_one_token_per_atom,
1188-
) in zip(molecules, is_chainable_biomolecules, is_first_mol_in_chains, one_token_per_atom):
1230+
) in zip(
1231+
molecules,
1232+
molecule_ids,
1233+
is_chainable_biomolecules,
1234+
is_first_mol_in_chains,
1235+
one_token_per_atom
1236+
):
11891237
num_atoms = mol.GetNumAtoms()
11901238

11911239
if mol_is_chainable_biomolecule and not mol_is_first_mol_in_chain:
11921240
token_bonds[offset, offset - 1] = True
11931241
token_bonds[offset - 1, offset] = True
11941242

11951243
if mol_is_one_token_per_atom:
1196-
coordinates = []
1197-
1198-
bonds = mol.GetBonds()
1199-
num_bonds = len(bonds)
1200-
1201-
for bond in bonds:
1202-
atom_start_index = bond.GetBeginAtomIdx()
1203-
atom_end_index = bond.GetEndAtomIdx()
12041244

1205-
coordinates.extend(
1206-
[
1207-
[atom_start_index, atom_end_index],
1208-
[atom_end_index, atom_start_index],
1209-
]
1210-
)
1211-
1212-
if num_bonds > 0:
1213-
has_bond = torch.zeros(num_atoms, num_atoms).bool()
1214-
1215-
coordinates = tensor(coordinates).long()
1216-
1217-
# has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
1218-
1219-
has_bond_stride = tensor(has_bond.stride())
1220-
flattened_coordinates = (coordinates * has_bond_stride).sum(dim = -1)
1221-
packed_has_bond, unpack_has_bond = pack_one(has_bond, '*')
1222-
1223-
packed_has_bond[flattened_coordinates] = True
1224-
has_bond = unpack_has_bond(packed_has_bond, '*')
1245+
maybe_cached_get_mol_has_bond = maybe_cache(
1246+
get_mol_has_bond,
1247+
cache = HAS_BOND_CACHE,
1248+
key = str(mol_id),
1249+
should_cache = mol_is_chainable_biomolecule.item()
1250+
)
12251251

1226-
# / ein.set_at
1252+
has_bond = maybe_cached_get_mol_has_bond(mol)
12271253

1254+
if exists(has_bond) and has_bond.numel() > 0:
1255+
num_atoms = mol.GetNumAtoms()
12281256
row_col_slice = slice(offset, offset + num_atoms)
12291257
token_bonds[row_col_slice, row_col_slice] = has_bond
12301258

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.37"
3+
version = "0.5.38"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)