@@ -38,11 +38,11 @@ class PrunerConfig:
3838
3939 # Optional parameters that get initialized
4040 energies : Array1D_float = field (default_factory = lambda : np .array ([]))
41- ewin : float = field (default = 0.0 )
41+ max_dE : float = field (default = 0.0 )
4242 debugfunction : Callable [[str ], None ] | None = field (default = None )
4343
4444 # Computed fields
45- calls : int = field (default = 0 , init = False )
45+ eval_calls : int = field (default = 0 , init = False )
4646 cache_calls : int = field (default = 0 , init = False )
4747 cache : set [tuple [int , int ]] = field (default_factory = lambda : set (), init = False )
4848
@@ -51,16 +51,21 @@ def __post_init__(self) -> None:
5151 self .mask = np .ones (shape = (self .structures .shape [0 ],), dtype = np .bool_ )
5252
5353 if len (self .energies ) != 0 :
54- assert self .ewin > 0.0 , (
55- "If you provide energies, please also provide an appropriate energy window ewin ."
54+ assert self .max_dE > 0.0 , (
55+ "If you provide energies, please also provide an appropriate energy window max_dE ."
5656 )
5757
5858 # Set defaults for optional parameters
5959 if len (self .energies ) == 0 :
60- self .energies = np .zeros (self .structures .shape [0 ])
60+ self .energies = np .zeros (self .structures .shape [0 ], dtype = float )
6161
62- if self .ewin == 0.0 :
63- self .ewin = 1.0
62+ assert len (self .energies ) == len (self .structures ), (
63+ "Please make sure that the energies "
64+ + "provided have the same len as the input structures."
65+ )
66+
67+ if self .max_dE == 0.0 :
68+ self .max_dE = 1.0
6469
6570 def evaluate_sim (self , * args : Any , ** kwargs : Any ) -> bool :
6671 """Stub method - override in subclasses as needed."""
@@ -176,7 +181,7 @@ def _main_compute_subrow(
176181 structure in structures, returning at the first instance of a match.
177182 Ignores structures that are False (0) in in_mask and does not perform
178183 the comparison if the energy difference between the structures is less
179- than self.ewin . Saves dissimilar structural pairs (i.e. that evaluate to
184+ than self.max_dE . Saves dissimilar structural pairs (i.e. that evaluate to
180185 False (0)) by adding them to self.cache, avoiding redundant calcaulations.
181186 """
182187 i1 = first_abs_index
@@ -191,16 +196,18 @@ def _main_compute_subrow(
191196 i2 = first_abs_index + 1 + i
192197 hash_value = (i1 , i2 )
193198
194- prunerconfig .calls += 1
195199 if hash_value in prunerconfig .cache :
196200 prunerconfig .cache_calls += 1
197201 continue
198202
199203 # if we have not computed the value before, check if the two
200204 # structures have close enough energy before running the comparison
201- elif np .abs (prunerconfig .energies [i1 ] - prunerconfig .energies [i2 ]) < prunerconfig .ewin :
205+ elif (
206+ np .abs (prunerconfig .energies [i1 ] - prunerconfig .energies [i2 ]) < prunerconfig .max_dE
207+ ):
202208 # function will return True whether the structures are similar,
203209 # and will stop iterating on this row, returning
210+ prunerconfig .eval_calls += 1
204211 if prunerconfig .evaluate_sim (i1 , i2 ):
205212 return True
206213
@@ -309,6 +316,14 @@ def prune(prunerconfig: PrunerConfig) -> tuple[Array2D_float, Array1D_bool]:
309316 out_mask = np .ones (shape = prunerconfig .structures .shape [0 ], dtype = np .bool_ )
310317 prunerconfig .cache = set ()
311318
319+ # sort structures by ascending energy: this will have the effect of
320+ # having energetically similar structures end up in the same chunk
321+ # and therefore being pruned early
322+ if np .abs (prunerconfig .energies [- 1 ]) > 0 :
323+ sorting_indices = np .argsort (prunerconfig .energies )
324+ prunerconfig .structures = prunerconfig .structures [sorting_indices ]
325+ prunerconfig .energies = prunerconfig .energies [sorting_indices ]
326+
312327 # split the structure array in subgroups and prune them internally
313328 for k in (
314329 500_000 ,
@@ -365,11 +380,17 @@ def prune(prunerconfig: PrunerConfig) -> tuple[Array2D_float, Array1D_bool]:
365380 + f"({ time_to_string (elapsed )} )"
366381 )
367382
368- fraction = 0 if prunerconfig .calls == 0 else prunerconfig .cache_calls / prunerconfig .calls
383+ if prunerconfig .eval_calls == 0 :
384+ fraction = 0.0
385+ else :
386+ fraction = prunerconfig .cache_calls / (
387+ prunerconfig .eval_calls + prunerconfig .cache_calls
388+ )
389+
369390 prunerconfig .debugfunction (
370391 f"DEBUG: { prunerconfig .__class__ .__name__ } - Used cached data "
371- + f"{ prunerconfig .cache_calls } /{ prunerconfig .calls } times, "
372- + f"{ 100 * fraction :.2f} % of total calls"
392+ + f"{ prunerconfig .cache_calls } /{ prunerconfig .eval_calls + prunerconfig . cache_calls } "
393+ + f" times, { 100 * fraction :.2f} % of total calls"
373394 )
374395
375396 return prunerconfig .structures [out_mask ], out_mask
@@ -380,6 +401,8 @@ def prune_by_rmsd(
380401 atoms : Array1D_str ,
381402 max_rmsd : float = 0.25 ,
382403 max_dev : float | None = None ,
404+ energies : Array1D_float | None = None ,
405+ max_dE : float = 0.0 ,
383406 debugfunction : Callable [[str ], None ] | None = None ,
384407) -> tuple [Array3D_float , Array1D_bool ]:
385408 """Remove duplicate structures using a heavy-atom RMSD metric.
@@ -391,6 +414,9 @@ def prune_by_rmsd(
391414 Similarity occurs for structures with both RMSD < max_rmsd and
392415 maximum deviation < max_dev. max_dev by default is 2 * max_rmsd.
393416 """
417+ if energies is None :
418+ energies = np .array ([])
419+
394420 # set default max_dev if not provided
395421 max_dev = max_dev or 2 * max_rmsd
396422
@@ -400,6 +426,8 @@ def prune_by_rmsd(
400426 atoms = atoms ,
401427 max_rmsd = max_rmsd ,
402428 max_dev = max_dev ,
429+ energies = energies ,
430+ max_dE = max_dE ,
403431 debugfunction = debugfunction ,
404432 )
405433
@@ -413,6 +441,8 @@ def prune_by_rmsd_rot_corr(
413441 graph : Graph ,
414442 max_rmsd : float = 0.25 ,
415443 max_dev : float | None = None ,
444+ energies : Array1D_float | None = None ,
445+ max_dE : float = 0.0 ,
416446 logfunction : Callable [[str ], None ] | None = None ,
417447 debugfunction : Callable [[str ], None ] | None = None ,
418448) -> tuple [Array3D_float , Array1D_bool ]:
@@ -535,10 +565,15 @@ def prune_by_rmsd_rot_corr(
535565 )
536566 logfunction ("\n " )
537567
568+ if energies is None :
569+ energies = np .array ([])
570+
538571 # Initialize PrunerConfig
539572 prunerconfig = RMSDRotCorrPrunerConfig (
540573 structures = structures ,
541574 atoms = atoms ,
575+ energies = energies ,
576+ max_dE = max_dE ,
542577 graph = graph ,
543578 torsions = torsions_ids ,
544579 debugfunction = debugfunction ,
@@ -561,19 +596,25 @@ def prune_by_moment_of_inertia(
561596 structures : Array3D_float ,
562597 atoms : Array1D_str ,
563598 max_deviation : float = 1e-2 ,
599+ energies : Array1D_float | None = None ,
600+ max_dE : float = 0.0 ,
564601 debugfunction : Callable [[str ], None ] | None = None ,
565602) -> tuple [Array3D_float , Array1D_bool ]:
566603 """Remove duplicate structures using a moments of inertia-based metric.
567604
568605 Remove duplicate structures (enantiomeric or rotameric) based on the
569- moments of inertia on the principal axes. If all three MOI
570- deviate less than max_deviation percent from another structure,
571- they are classified as rotamers or enantiomers and therefore only one
572- of them is kept (i.e. max_deviation = 0.1 is 10% relative deviation).
606+ moment of inertia on the principal axes. If all three deviate less than
607+ max_deviation percent from another one, the structure is removed from
608+ the ensemble (i.e. max_deviation = 0.1 is 10% relative deviation).
573609 """
610+ if energies is None :
611+ energies = np .array ([])
612+
574613 # set up PrunerConfig dataclass
575614 prunerconfig = MOIPrunerConfig (
576615 structures = structures ,
616+ energies = energies ,
617+ max_dE = max_dE ,
577618 debugfunction = debugfunction ,
578619 max_dev = max_deviation ,
579620 masses = np .array ([elements .symbol (a ).mass for a in atoms ]),
0 commit comments