Skip to content

Commit 101c6eb

Browse files
committed
adds timeout
1 parent e46f843 commit 101c6eb

File tree

4 files changed

+67
-28
lines changed

4 files changed

+67
-28
lines changed

pixi.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

prism_pruner/pruner.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from prism_pruner.graph_manipulations import graphize
1515
from prism_pruner.periodic_table import MASSES_TABLE
1616
from prism_pruner.rmsd import rmsd_and_max
17+
from prism_pruner.timeout_context import Timeout
1718
from prism_pruner.torsion_module import (
1819
_get_rotation_mask,
1920
get_angles,
@@ -44,6 +45,8 @@ class PrunerConfig:
4445
energies: Array1D_float = field(default_factory=lambda: np.array([]))
4546
max_dE: float = field(default=0.0)
4647
debugfunction: Callable[[str], None] | None = field(default=None)
48+
timeout_s: int = field(default=60)
49+
logfunction: Callable[[str], None] | None = print
4750

4851
# Computed fields
4952
eval_calls: int = field(default=0, init=False)
@@ -143,7 +146,7 @@ def __post_init__(self) -> None:
143146
assert type(self.atoms) is np.ndarray
144147

145148
# pre-compute heavy atom mask once
146-
if self.heavy_atoms_only and np.count_nonzero(self.atoms != "H") > 0:
149+
if self.heavy_atoms_only:
147150
self._heavy_mask: Array1D_bool = self.atoms != "H"
148151
else:
149152
self._heavy_mask = np.ones(self.atoms.shape[0], dtype=np.bool_)
@@ -187,15 +190,6 @@ def __post_init__(self) -> None:
187190
for c, coord in enumerate(self.structures)
188191
}
189192

190-
# check the first structure to assess if any
191-
# moment is zero and we should therefore not
192-
# bother using it to compare structures
193-
self.check_I_on_axis: tuple[bool, bool, bool] = (
194-
self.moi_vecs[0][0] > 1e-6,
195-
self.moi_vecs[0][1] > 1e-6,
196-
self.moi_vecs[0][2] > 1e-6,
197-
)
198-
199193
def evaluate_sim(self, i1: int, i2: int) -> bool:
200194
"""Return whether the structures are similar."""
201195
im_1 = self.moi_vecs[i1]
@@ -204,8 +198,8 @@ def evaluate_sim(self, i1: int, i2: int) -> bool:
204198
# compare the three MOIs via a Python loop:
205199
# apparently much faster than numpy array operations
206200
# for such a small array!
207-
for j, check_axis in enumerate(self.check_I_on_axis):
208-
if check_axis and np.abs(im_1[j] - im_2[j]) / im_1[j] >= self.max_dev:
201+
for j in range(3):
202+
if np.abs(im_1[j] - im_2[j]) / im_1[j] >= self.max_dev:
209203
return False
210204
return True
211205

@@ -352,6 +346,7 @@ def _run(prunerconfig: PrunerConfig) -> tuple[Array2D_float, Array1D_bool]:
352346
Sets the self.structures and the corresponding self.mask attributes.
353347
"""
354348
start_t = perf_counter()
349+
timed_out = False
355350

356351
# initialize the output mask
357352
out_mask = np.ones(shape=prunerconfig.structures.shape[0], dtype=np.bool_)
@@ -388,18 +383,27 @@ def _run(prunerconfig: PrunerConfig) -> tuple[Array2D_float, Array1D_bool]:
388383
):
389384
# choose only k values such that every subgroup
390385
# has on average at least twenty active structures in it
391-
if k == 1 or 20 * k < np.count_nonzero(out_mask):
386+
if (k == 1 or (20 * k < np.count_nonzero(out_mask))) and not timed_out:
392387
before = np.count_nonzero(out_mask)
393388

394389
start_t_k = perf_counter()
395390

396-
# compute similarities and get back the out_mask
397-
# and the pairings to be added to cache
398-
out_mask = _main_compute_group(
399-
prunerconfig,
400-
out_mask,
401-
k=k,
402-
)
391+
try:
392+
with Timeout(seconds=int(prunerconfig.timeout_s - (perf_counter() - start_t))):
393+
# compute similarities and get back the out_mask
394+
# and the pairings to be added to cache
395+
out_mask = _main_compute_group(
396+
prunerconfig,
397+
out_mask,
398+
k=k,
399+
)
400+
except TimeoutError:
401+
timed_out = True
402+
if prunerconfig.debugfunction is not None:
403+
prunerconfig.debugfunction(
404+
f"TIMEOUT: {prunerconfig.__class__.__name__} timed out on k={k} "
405+
f"({prunerconfig.timeout_s} s)."
406+
)
403407

404408
after = np.count_nonzero(out_mask)
405409
newly_discarded = before - after
@@ -444,6 +448,7 @@ def prune_by_rmsd(
444448
max_dev: float | None = None,
445449
energies: Array1D_float | None = None,
446450
max_dE: float = 0.0,
451+
timeout_s: int = 60,
447452
heavy_atoms_only: bool = True,
448453
debugfunction: Callable[[str], None] | None = None,
449454
) -> tuple[Array3D_float, Array1D_bool]:
@@ -484,6 +489,7 @@ def prune_by_rmsd(
484489
max_dev=max_dev,
485490
energies=energies,
486491
max_dE=max_dE,
492+
timeout_s=timeout_s,
487493
debugfunction=debugfunction,
488494
heavy_atoms_only=heavy_atoms_only,
489495
)
@@ -512,12 +518,8 @@ def _batch_rmsd_prune(
512518
start_t = perf_counter()
513519

514520
N = len(structures)
515-
516-
# check how many heavy atoms: if none, use all
517-
M = int(np.count_nonzero(atoms != "H"))
518-
519521
heavy_mask: Array1D_bool = (
520-
atoms != "H" if (heavy_atoms_only and M > 0) else np.ones(structures.shape[1], dtype=bool)
522+
atoms != "H" if heavy_atoms_only else np.ones(structures.shape[1], dtype=bool)
521523
)
522524
M = int(heavy_mask.sum())
523525

@@ -688,6 +690,7 @@ def prune_by_rmsd_rot_corr(
688690
max_dev: float | None = None,
689691
energies: Array1D_float | None = None,
690692
max_dE: float = 0.0,
693+
timeout_s: int = 60,
691694
heavy_atoms_only: bool = True,
692695
logfunction: Callable[[str], None] | None = None,
693696
debugfunction: Callable[[str], None] | None = None,
@@ -861,6 +864,7 @@ def prune_by_rmsd_rot_corr(
861864
max_dev=max_dev,
862865
single_atom_masks=single_atom_masks,
863866
rotation_masks=rotation_masks,
867+
timeout_s=timeout_s,
864868
)
865869
_, mask = _run(prunerconfig)
866870

@@ -879,6 +883,7 @@ def prune_by_moment_of_inertia(
879883
max_deviation: float = 1e-2,
880884
energies: Array1D_float | None = None,
881885
max_dE: float = 0.0,
886+
timeout_s: int = 60,
882887
debugfunction: Callable[[str], None] | None = None,
883888
) -> tuple[Array3D_float, Array1D_bool]:
884889
"""Remove duplicate structures using a moments of inertia-based metric.
@@ -896,6 +901,7 @@ def prune_by_moment_of_inertia(
896901
structures=structures,
897902
energies=energies,
898903
max_dE=max_dE,
904+
timeout_s=timeout_s,
899905
debugfunction=debugfunction,
900906
max_dev=max_deviation,
901907
masses=np.array([MASSES_TABLE[a] for a in atoms]),
@@ -914,6 +920,7 @@ def prune(
914920
max_dE: float = 0.0,
915921
debugfunction: Callable[[str], None] | None = None,
916922
logfunction: Callable[[str], None] | None = None,
923+
timeout_s: int = 60,
917924
) -> tuple[Array3D_float, Array1D_bool]:
918925
"""Remove duplicate structures.
919926
@@ -923,6 +930,8 @@ def prune(
923930
Will only compare structures less than max_dE apart
924931
in energy, if energies and max_dE are provided.
925932
933+
Each pruning step will be timed out after timeout_s seconds.
934+
926935
Note: will use automatic pruning thresholds.
927936
"""
928937
if energies is None:
@@ -939,8 +948,10 @@ def prune(
939948
max_deviation=0.01,
940949
energies=energies,
941950
max_dE=max_dE,
951+
timeout_s=timeout_s,
942952
debugfunction=debugfunction,
943953
)
954+
944955
energies = energies[mask]
945956
active_indices = active_indices[mask]
946957

@@ -956,6 +967,7 @@ def prune(
956967
max_dev=0.5,
957968
energies=energies,
958969
max_dE=max_dE,
970+
timeout_s=timeout_s,
959971
debugfunction=debugfunction,
960972
)
961973
energies = energies[mask]
@@ -976,6 +988,7 @@ def prune(
976988
max_dev=0.5,
977989
energies=energies,
978990
max_dE=max_dE,
991+
timeout_s=timeout_s,
979992
debugfunction=debugfunction,
980993
logfunction=logfunction,
981994
)

prism_pruner/timeout_context.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""PRISM - Pruning Interface for Similar Molecules."""
2+
3+
import signal
4+
from typing import Any
5+
6+
7+
class Timeout:
8+
"""Timeout context manager."""
9+
10+
def __init__(self, seconds: int = 60, error_message: str = "Timeout") -> None:
11+
"""Define the __init__ method of the context manager."""
12+
self.seconds = seconds
13+
self.error_message = error_message
14+
15+
def handle_timeout(self, signum: Any, frame: Any) -> None:
16+
"""Handle the timeout signal."""
17+
raise TimeoutError(self.error_message)
18+
19+
def __enter__(self) -> None:
20+
"""Define the __enter__ method of the context manager."""
21+
signal.signal(signal.SIGALRM, self.handle_timeout)
22+
signal.alarm(self.seconds)
23+
24+
def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
25+
"""Define the __exit__ method of the context manager."""
26+
signal.alarm(0)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "prism_pruner"
33
description = "Prism Pruner"
44
license = "MIT"
5-
version = "0.0.9"
5+
version = "0.1.0"
66
readme = "README.md"
77
keywords = []
88
authors = [{name = "Nicolò Tampellini", email = "nicolo.tampellini@yale.edu"}]

0 commit comments

Comments
 (0)