1414from prism_pruner .graph_manipulations import graphize
1515from prism_pruner .periodic_table import MASSES_TABLE
1616from prism_pruner .rmsd import rmsd_and_max
17+ from prism_pruner .timeout_context import Timeout
1718from prism_pruner .torsion_module import (
1819 _get_rotation_mask ,
1920 get_angles ,
@@ -44,6 +45,8 @@ class PrunerConfig:
4445 energies : Array1D_float = field (default_factory = lambda : np .array ([]))
4546 max_dE : float = field (default = 0.0 )
4647 debugfunction : Callable [[str ], None ] | None = field (default = None )
48+ timeout_s : int = field (default = 60 )
49+ logfunction : Callable [[str ], None ] | None = print
4750
4851 # Computed fields
4952 eval_calls : int = field (default = 0 , init = False )
@@ -143,7 +146,7 @@ def __post_init__(self) -> None:
143146 assert type (self .atoms ) is np .ndarray
144147
145148 # pre-compute heavy atom mask once
146- if self .heavy_atoms_only and np . count_nonzero ( self . atoms != "H" ) > 0 :
149+ if self .heavy_atoms_only :
147150 self ._heavy_mask : Array1D_bool = self .atoms != "H"
148151 else :
149152 self ._heavy_mask = np .ones (self .atoms .shape [0 ], dtype = np .bool_ )
@@ -187,15 +190,6 @@ def __post_init__(self) -> None:
187190 for c , coord in enumerate (self .structures )
188191 }
189192
190- # check the first structure to assess if any
191- # moment is zero and we should therefore not
192- # bother using it to compare structures
193- self .check_I_on_axis : tuple [bool , bool , bool ] = (
194- self .moi_vecs [0 ][0 ] > 1e-6 ,
195- self .moi_vecs [0 ][1 ] > 1e-6 ,
196- self .moi_vecs [0 ][2 ] > 1e-6 ,
197- )
198-
199193 def evaluate_sim (self , i1 : int , i2 : int ) -> bool :
200194 """Return whether the structures are similar."""
201195 im_1 = self .moi_vecs [i1 ]
@@ -204,8 +198,8 @@ def evaluate_sim(self, i1: int, i2: int) -> bool:
204198 # compare the three MOIs via a Python loop:
205199 # apparently much faster than numpy array operations
206200 # for such a small array!
207- for j , check_axis in enumerate ( self . check_I_on_axis ):
208- if check_axis and np .abs (im_1 [j ] - im_2 [j ]) / im_1 [j ] >= self .max_dev :
201+ for j in range ( 3 ):
202+ if np .abs (im_1 [j ] - im_2 [j ]) / im_1 [j ] >= self .max_dev :
209203 return False
210204 return True
211205
@@ -352,6 +346,7 @@ def _run(prunerconfig: PrunerConfig) -> tuple[Array2D_float, Array1D_bool]:
352346 Sets the self.structures and the corresponding self.mask attributes.
353347 """
354348 start_t = perf_counter ()
349+ timed_out = False
355350
356351 # initialize the output mask
357352 out_mask = np .ones (shape = prunerconfig .structures .shape [0 ], dtype = np .bool_ )
@@ -388,18 +383,27 @@ def _run(prunerconfig: PrunerConfig) -> tuple[Array2D_float, Array1D_bool]:
388383 ):
389384 # choose only k values such that every subgroup
390385 # has on average at least twenty active structures in it
391- if k == 1 or 20 * k < np .count_nonzero (out_mask ):
386+ if ( k == 1 or ( 20 * k < np .count_nonzero (out_mask ))) and not timed_out :
392387 before = np .count_nonzero (out_mask )
393388
394389 start_t_k = perf_counter ()
395390
396- # compute similarities and get back the out_mask
397- # and the pairings to be added to cache
398- out_mask = _main_compute_group (
399- prunerconfig ,
400- out_mask ,
401- k = k ,
402- )
391+ try :
392+ with Timeout (seconds = int (prunerconfig .timeout_s - (perf_counter () - start_t ))):
393+ # compute similarities and get back the out_mask
394+ # and the pairings to be added to cache
395+ out_mask = _main_compute_group (
396+ prunerconfig ,
397+ out_mask ,
398+ k = k ,
399+ )
400+ except TimeoutError :
401+ timed_out = True
402+ if prunerconfig .debugfunction is not None :
403+ prunerconfig .debugfunction (
404+ f"TIMEOUT: { prunerconfig .__class__ .__name__ } timed out on k={ k } "
405+ f"({ prunerconfig .timeout_s } s)."
406+ )
403407
404408 after = np .count_nonzero (out_mask )
405409 newly_discarded = before - after
@@ -444,6 +448,7 @@ def prune_by_rmsd(
444448 max_dev : float | None = None ,
445449 energies : Array1D_float | None = None ,
446450 max_dE : float = 0.0 ,
451+ timeout_s : int = 60 ,
447452 heavy_atoms_only : bool = True ,
448453 debugfunction : Callable [[str ], None ] | None = None ,
449454) -> tuple [Array3D_float , Array1D_bool ]:
@@ -484,6 +489,7 @@ def prune_by_rmsd(
484489 max_dev = max_dev ,
485490 energies = energies ,
486491 max_dE = max_dE ,
492+ timeout_s = timeout_s ,
487493 debugfunction = debugfunction ,
488494 heavy_atoms_only = heavy_atoms_only ,
489495 )
@@ -512,12 +518,8 @@ def _batch_rmsd_prune(
512518 start_t = perf_counter ()
513519
514520 N = len (structures )
515-
516- # check how many heavy atoms: if none, use all
517- M = int (np .count_nonzero (atoms != "H" ))
518-
519521 heavy_mask : Array1D_bool = (
520- atoms != "H" if ( heavy_atoms_only and M > 0 ) else np .ones (structures .shape [1 ], dtype = bool )
522+ atoms != "H" if heavy_atoms_only else np .ones (structures .shape [1 ], dtype = bool )
521523 )
522524 M = int (heavy_mask .sum ())
523525
@@ -688,6 +690,7 @@ def prune_by_rmsd_rot_corr(
688690 max_dev : float | None = None ,
689691 energies : Array1D_float | None = None ,
690692 max_dE : float = 0.0 ,
693+ timeout_s : int = 60 ,
691694 heavy_atoms_only : bool = True ,
692695 logfunction : Callable [[str ], None ] | None = None ,
693696 debugfunction : Callable [[str ], None ] | None = None ,
@@ -861,6 +864,7 @@ def prune_by_rmsd_rot_corr(
861864 max_dev = max_dev ,
862865 single_atom_masks = single_atom_masks ,
863866 rotation_masks = rotation_masks ,
867+ timeout_s = timeout_s ,
864868 )
865869 _ , mask = _run (prunerconfig )
866870
@@ -879,6 +883,7 @@ def prune_by_moment_of_inertia(
879883 max_deviation : float = 1e-2 ,
880884 energies : Array1D_float | None = None ,
881885 max_dE : float = 0.0 ,
886+ timeout_s : int = 60 ,
882887 debugfunction : Callable [[str ], None ] | None = None ,
883888) -> tuple [Array3D_float , Array1D_bool ]:
884889 """Remove duplicate structures using a moments of inertia-based metric.
@@ -896,6 +901,7 @@ def prune_by_moment_of_inertia(
896901 structures = structures ,
897902 energies = energies ,
898903 max_dE = max_dE ,
904+ timeout_s = timeout_s ,
899905 debugfunction = debugfunction ,
900906 max_dev = max_deviation ,
901907 masses = np .array ([MASSES_TABLE [a ] for a in atoms ]),
@@ -914,6 +920,7 @@ def prune(
914920 max_dE : float = 0.0 ,
915921 debugfunction : Callable [[str ], None ] | None = None ,
916922 logfunction : Callable [[str ], None ] | None = None ,
923+ timeout_s : int = 60 ,
917924) -> tuple [Array3D_float , Array1D_bool ]:
918925 """Remove duplicate structures.
919926
@@ -923,6 +930,8 @@ def prune(
923930 Will only compare structures less than max_dE apart
924931 in energy, if energies and max_dE are provided.
925932
933+ Each pruning step will be timed out after timeout_s seconds.
934+
926935 Note: will use automatic pruning thresholds.
927936 """
928937 if energies is None :
@@ -939,8 +948,10 @@ def prune(
939948 max_deviation = 0.01 ,
940949 energies = energies ,
941950 max_dE = max_dE ,
951+ timeout_s = timeout_s ,
942952 debugfunction = debugfunction ,
943953 )
954+
944955 energies = energies [mask ]
945956 active_indices = active_indices [mask ]
946957
@@ -956,6 +967,7 @@ def prune(
956967 max_dev = 0.5 ,
957968 energies = energies ,
958969 max_dE = max_dE ,
970+ timeout_s = timeout_s ,
959971 debugfunction = debugfunction ,
960972 )
961973 energies = energies [mask ]
@@ -976,6 +988,7 @@ def prune(
976988 max_dev = 0.5 ,
977989 energies = energies ,
978990 max_dE = max_dE ,
991+ timeout_s = timeout_s ,
979992 debugfunction = debugfunction ,
980993 logfunction = logfunction ,
981994 )
0 commit comments