213213# simple caching
214214
215215ATOMPAIR_IDS_CACHE = dict ()
216+ HAS_BOND_CACHE = dict ()
216217
217218@typecheck
218219def maybe_cache (
219220 fn ,
220221 * ,
221222 cache : dict ,
222- key : str ,
223+ key : str | None ,
223224 should_cache : bool = True
224225) -> Callable :
225226
226- if not should_cache :
227+ if not should_cache or not exists ( key ) :
227228 return fn
228229
229230 @wraps (fn )
@@ -246,7 +247,7 @@ def inner(*args, **kwargs):
246247def get_atompair_ids (
247248 mol : Mol ,
248249 directed_bonds : bool
249- ) -> Tensor | None :
250+ ) -> Int [ 'm m' ] | None :
250251
251252 coordinates = []
252253 updates = []
@@ -304,6 +305,46 @@ def get_atompair_ids(
304305
305306 return mol_atompair_ids
306307
308+ @typecheck
309+ def get_mol_has_bond (
310+ mol : Mol
311+ ) -> Bool ['m m' ] | None :
312+
313+ coordinates = []
314+
315+ bonds = mol .GetBonds ()
316+ num_bonds = len (bonds )
317+
318+ for bond in bonds :
319+ atom_start_index = bond .GetBeginAtomIdx ()
320+ atom_end_index = bond .GetEndAtomIdx ()
321+
322+ coordinates .extend (
323+ [
324+ [atom_start_index , atom_end_index ],
325+ [atom_end_index , atom_start_index ],
326+ ]
327+ )
328+
329+ if num_bonds == 0 :
330+ return None
331+
332+ num_atoms = mol .GetNumAtoms ()
333+ has_bond = torch .zeros (num_atoms , num_atoms ).bool ()
334+
335+ coordinates = tensor (coordinates ).long ()
336+
337+ # has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
338+
339+ has_bond_stride = tensor (has_bond .stride ())
340+ flattened_coordinates = (coordinates * has_bond_stride ).sum (dim = - 1 )
341+ packed_has_bond , unpack_has_bond = pack_one (has_bond , '*' )
342+
343+ packed_has_bond [flattened_coordinates ] = True
344+ has_bond = unpack_has_bond (packed_has_bond , '*' )
345+
346+ return has_bond
347+
307348# functions
308349
309350
@@ -1182,49 +1223,36 @@ def molecule_lengthed_molecule_input_to_atom_input(
11821223
11831224 for (
11841225 mol ,
1226+ mol_id ,
11851227 mol_is_chainable_biomolecule ,
11861228 mol_is_first_mol_in_chain ,
11871229 mol_is_one_token_per_atom ,
1188- ) in zip (molecules , is_chainable_biomolecules , is_first_mol_in_chains , one_token_per_atom ):
1230+ ) in zip (
1231+ molecules ,
1232+ molecule_ids ,
1233+ is_chainable_biomolecules ,
1234+ is_first_mol_in_chains ,
1235+ one_token_per_atom
1236+ ):
11891237 num_atoms = mol .GetNumAtoms ()
11901238
11911239 if mol_is_chainable_biomolecule and not mol_is_first_mol_in_chain :
11921240 token_bonds [offset , offset - 1 ] = True
11931241 token_bonds [offset - 1 , offset ] = True
11941242
11951243 if mol_is_one_token_per_atom :
1196- coordinates = []
1197-
1198- bonds = mol .GetBonds ()
1199- num_bonds = len (bonds )
1200-
1201- for bond in bonds :
1202- atom_start_index = bond .GetBeginAtomIdx ()
1203- atom_end_index = bond .GetEndAtomIdx ()
12041244
1205- coordinates .extend (
1206- [
1207- [atom_start_index , atom_end_index ],
1208- [atom_end_index , atom_start_index ],
1209- ]
1210- )
1211-
1212- if num_bonds > 0 :
1213- has_bond = torch .zeros (num_atoms , num_atoms ).bool ()
1214-
1215- coordinates = tensor (coordinates ).long ()
1216-
1217- # has_bond = einx.set_at("[h w], c [2], c -> [h w]", has_bond, coordinates, updates)
1218-
1219- has_bond_stride = tensor (has_bond .stride ())
1220- flattened_coordinates = (coordinates * has_bond_stride ).sum (dim = - 1 )
1221- packed_has_bond , unpack_has_bond = pack_one (has_bond , '*' )
1222-
1223- packed_has_bond [flattened_coordinates ] = True
1224- has_bond = unpack_has_bond (packed_has_bond , '*' )
1245+ maybe_cached_get_mol_has_bond = maybe_cache (
1246+ get_mol_has_bond ,
1247+ cache = HAS_BOND_CACHE ,
1248+ key = str (mol_id ),
1249+ should_cache = mol_is_chainable_biomolecule .item ()
1250+ )
12251251
1226- # / ein.set_at
1252+ has_bond = maybe_cached_get_mol_has_bond ( mol )
12271253
1254+ if exists (has_bond ) and has_bond .numel () > 0 :
1255+ num_atoms = mol .GetNumAtoms ()
12281256 row_col_slice = slice (offset , offset + num_atoms )
12291257 token_bonds [row_col_slice , row_col_slice ] = has_bond
12301258
0 commit comments