Skip to content

Commit 762e28c

Browse files
authored
Merge pull request #11 from ntampellini/performance/rmsd
[Performance] Pruning energy-sorted ensembles
2 parents d0e55e9 + 7ba0407 commit 762e28c

File tree

13 files changed

+164625
-868
lines changed

13 files changed

+164625
-868
lines changed

examples/example_notebook.ipynb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
},
5353
{
5454
"cell_type": "code",
55-
"execution_count": 3,
55+
"execution_count": 4,
5656
"id": "31198a3f",
5757
"metadata": {},
5858
"outputs": [
@@ -61,12 +61,12 @@
6161
"output_type": "stream",
6262
"text": [
6363
"DEBUG: MOIPrunerConfig - k=50, rejected 449 (keeping 637/1086), in 0.1 s\n",
64-
"DEBUG: MOIPrunerConfig - k=20, rejected 109 (keeping 528/1086), in 0.1 s\n",
65-
"DEBUG: MOIPrunerConfig - k=10, rejected 27 (keeping 501/1086), in 0.1 s\n",
66-
"DEBUG: MOIPrunerConfig - k=5, rejected 28 (keeping 473/1086), in 0.4 s\n",
67-
"DEBUG: MOIPrunerConfig - k=2, rejected 38 (keeping 435/1086), in 0.5 s\n",
68-
"DEBUG: MOIPrunerConfig - k=1, rejected 10 (keeping 425/1086), in 0.6 s\n",
69-
"DEBUG: MOIPrunerConfig - keeping 425/1086 (1.9 s)\n",
64+
"DEBUG: MOIPrunerConfig - k=20, rejected 109 (keeping 528/1086), in 0.0 s\n",
65+
"DEBUG: MOIPrunerConfig - k=10, rejected 27 (keeping 501/1086), in 0.0 s\n",
66+
"DEBUG: MOIPrunerConfig - k=5, rejected 28 (keeping 473/1086), in 0.1 s\n",
67+
"DEBUG: MOIPrunerConfig - k=2, rejected 38 (keeping 435/1086), in 0.2 s\n",
68+
"DEBUG: MOIPrunerConfig - k=1, rejected 10 (keeping 425/1086), in 0.3 s\n",
69+
"DEBUG: MOIPrunerConfig - keeping 425/1086 (0.8 s)\n",
7070
"DEBUG: MOIPrunerConfig - Used cached data 105595/211707 times, 49.88% of total calls\n"
7171
]
7272
},
@@ -76,7 +76,7 @@
7676
"(425, 136, 3)"
7777
]
7878
},
79-
"execution_count": 3,
79+
"execution_count": 4,
8080
"metadata": {},
8181
"output_type": "execute_result"
8282
}

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

prism_pruner/algebra.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -118,27 +118,23 @@ def quaternion_to_rotation_matrix(quat: Array1D_float | Sequence[float]) -> Arra
118118

119119

120120
def get_inertia_moments(coords: Array3D_float, masses: Array1D_float) -> Array1D_float:
121-
"""
122-
Find the moments of inertia of the three principal axes.
121+
"""Compute the principal moments of inertia of a molecule.
123122
124-
:return: diagonal of the diagonalized inertia tensor, that is
125-
a shape (3,) array with the moments of inertia along the main axes.
126-
(I_x, I_y and largest I_z last)
123+
Returns a length-3 array [I_x, I_y, I_z], sorted ascending.
127124
"""
128-
# Center coordinates around the center of mass
129-
coords = coords - np.sum(coords * masses[:, np.newaxis], axis=0)
130-
131-
# Compute r^2 for each atom
132-
norms_squared = np.einsum("ni,ni->n", coords, coords)
125+
# Shift to center of mass
126+
com = np.sum(coords * masses[:, np.newaxis], axis=0) / np.sum(masses)
127+
coords = coords - com
133128

134-
# Build inertia tensor using einsum
135-
total = np.sum(masses * norms_squared)
136-
inertia_moment_matrix = total * np.eye(3) - np.einsum("n,ni,nj->ij", masses, coords, coords)
129+
# Compute inertia tensor
130+
norms_sq = np.einsum("ni,ni->n", coords, coords)
131+
total = np.sum(masses * norms_sq)
132+
I_matrix = total * np.eye(3) - np.einsum("n,ni,nj->ij", masses, coords, coords)
137133

138-
# diagonalize the matrix and return the diagonal
139-
inertia_moment_matrix = diagonalize(inertia_moment_matrix)
134+
# Principal moments via symmetric eigendecomposition
135+
moments, _ = np.linalg.eigh(I_matrix)
140136

141-
return np.diag(inertia_moment_matrix)
137+
return np.sort(moments)
142138

143139

144140
def diagonalize(a: Array2D_float) -> Array2D_float:

prism_pruner/conformer_ensemble.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""ConformerEnsemble class."""
22

3-
from dataclasses import dataclass
3+
import re
4+
from dataclasses import dataclass, field
45
from pathlib import Path
56
from typing import Self
67

78
import numpy as np
89

9-
from prism_pruner.typing import Array1D_str, Array2D_float, Array3D_float
10+
from prism_pruner.typing import Array1D_float, Array1D_str, Array2D_float, Array3D_float
1011

1112

1213
@dataclass
@@ -15,15 +16,22 @@ class ConformerEnsemble:
1516

1617
coords: Array3D_float
1718
atoms: Array1D_str
19+
energies: Array1D_float = field(default_factory=lambda: np.array([]))
1820

1921
@classmethod
20-
def from_xyz(cls, file: Path | str) -> Self:
22+
def from_xyz(cls, file: Path | str, read_energies: bool = False) -> Self:
2123
"""Generate ensemble from a multiple conformer xyz file."""
2224
coords = []
2325
atoms = []
26+
energies = []
2427
with Path(file).open() as f:
2528
for num in f:
26-
_comment = next(f)
29+
if read_energies:
30+
energy = next(re.finditer(r"-*\d+\.\d+", next(f))).group()
31+
energies.append(float(energy))
32+
else:
33+
_comment = next(f)
34+
2735
conf_atoms = []
2836
conf_coords = []
2937
for _ in range(int(num)):
@@ -34,7 +42,7 @@ def from_xyz(cls, file: Path | str) -> Self:
3442
atoms.append(conf_atoms)
3543
coords.append(conf_coords)
3644

37-
return cls(coords=np.array(coords), atoms=np.array(atoms[0]))
45+
return cls(coords=np.array(coords), atoms=np.array(atoms[0]), energies=np.array(energies))
3846

3947
def to_xyz(self, file: Path | str) -> None:
4048
"""Write ensemble to an xyz file."""

prism_pruner/pruner.py

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ class PrunerConfig:
3838

3939
# Optional parameters that get initialized
4040
energies: Array1D_float = field(default_factory=lambda: np.array([]))
41-
ewin: float = field(default=0.0)
41+
max_dE: float = field(default=0.0)
4242
debugfunction: Callable[[str], None] | None = field(default=None)
4343

4444
# Computed fields
45-
calls: int = field(default=0, init=False)
45+
eval_calls: int = field(default=0, init=False)
4646
cache_calls: int = field(default=0, init=False)
4747
cache: set[tuple[int, int]] = field(default_factory=lambda: set(), init=False)
4848

@@ -51,16 +51,21 @@ def __post_init__(self) -> None:
5151
self.mask = np.ones(shape=(self.structures.shape[0],), dtype=np.bool_)
5252

5353
if len(self.energies) != 0:
54-
assert self.ewin > 0.0, (
55-
"If you provide energies, please also provide an appropriate energy window ewin."
54+
assert self.max_dE > 0.0, (
55+
"If you provide energies, please also provide an appropriate energy window max_dE."
5656
)
5757

5858
# Set defaults for optional parameters
5959
if len(self.energies) == 0:
60-
self.energies = np.zeros(self.structures.shape[0])
60+
self.energies = np.zeros(self.structures.shape[0], dtype=float)
6161

62-
if self.ewin == 0.0:
63-
self.ewin = 1.0
62+
assert len(self.energies) == len(self.structures), (
63+
"Please make sure that the energies "
64+
+ "provided have the same len as the input structures."
65+
)
66+
67+
if self.max_dE == 0.0:
68+
self.max_dE = 1.0
6469

6570
def evaluate_sim(self, *args: Any, **kwargs: Any) -> bool:
6671
"""Stub method - override in subclasses as needed."""
@@ -176,7 +181,7 @@ def _main_compute_subrow(
176181
structure in structures, returning at the first instance of a match.
177182
Ignores structures that are False (0) in in_mask and does not perform
178183
the comparison if the energy difference between the structures is not less
179-
than self.ewin. Saves dissimilar structural pairs (i.e. that evaluate to
184+
than self.max_dE. Saves dissimilar structural pairs (i.e. that evaluate to
180185
False (0)) by adding them to self.cache, avoiding redundant calculations.
181186
"""
182187
i1 = first_abs_index
@@ -191,16 +196,18 @@ def _main_compute_subrow(
191196
i2 = first_abs_index + 1 + i
192197
hash_value = (i1, i2)
193198

194-
prunerconfig.calls += 1
195199
if hash_value in prunerconfig.cache:
196200
prunerconfig.cache_calls += 1
197201
continue
198202

199203
# if we have not computed the value before, check if the two
200204
# structures have close enough energy before running the comparison
201-
elif np.abs(prunerconfig.energies[i1] - prunerconfig.energies[i2]) < prunerconfig.ewin:
205+
elif (
206+
np.abs(prunerconfig.energies[i1] - prunerconfig.energies[i2]) < prunerconfig.max_dE
207+
):
202208
# function will return True if the structures are similar,
203209
# and will stop iterating on this row, returning
210+
prunerconfig.eval_calls += 1
204211
if prunerconfig.evaluate_sim(i1, i2):
205212
return True
206213

@@ -309,6 +316,14 @@ def prune(prunerconfig: PrunerConfig) -> tuple[Array2D_float, Array1D_bool]:
309316
out_mask = np.ones(shape=prunerconfig.structures.shape[0], dtype=np.bool_)
310317
prunerconfig.cache = set()
311318

319+
# sort structures by ascending energy: this will have the effect of
320+
# having energetically similar structures end up in the same chunk
321+
# and therefore being pruned early
322+
if np.abs(prunerconfig.energies[-1]) > 0:
323+
sorting_indices = np.argsort(prunerconfig.energies)
324+
prunerconfig.structures = prunerconfig.structures[sorting_indices]
325+
prunerconfig.energies = prunerconfig.energies[sorting_indices]
326+
312327
# split the structure array in subgroups and prune them internally
313328
for k in (
314329
500_000,
@@ -365,11 +380,17 @@ def prune(prunerconfig: PrunerConfig) -> tuple[Array2D_float, Array1D_bool]:
365380
+ f"({time_to_string(elapsed)})"
366381
)
367382

368-
fraction = 0 if prunerconfig.calls == 0 else prunerconfig.cache_calls / prunerconfig.calls
383+
if prunerconfig.eval_calls == 0:
384+
fraction = 0.0
385+
else:
386+
fraction = prunerconfig.cache_calls / (
387+
prunerconfig.eval_calls + prunerconfig.cache_calls
388+
)
389+
369390
prunerconfig.debugfunction(
370391
f"DEBUG: {prunerconfig.__class__.__name__} - Used cached data "
371-
+ f"{prunerconfig.cache_calls}/{prunerconfig.calls} times, "
372-
+ f"{100 * fraction:.2f}% of total calls"
392+
+ f"{prunerconfig.cache_calls}/{prunerconfig.eval_calls + prunerconfig.cache_calls}"
393+
+ f" times, {100 * fraction:.2f}% of total calls"
373394
)
374395

375396
return prunerconfig.structures[out_mask], out_mask
@@ -380,6 +401,8 @@ def prune_by_rmsd(
380401
atoms: Array1D_str,
381402
max_rmsd: float = 0.25,
382403
max_dev: float | None = None,
404+
energies: Array1D_float | None = None,
405+
max_dE: float = 0.0,
383406
debugfunction: Callable[[str], None] | None = None,
384407
) -> tuple[Array3D_float, Array1D_bool]:
385408
"""Remove duplicate structures using a heavy-atom RMSD metric.
@@ -391,6 +414,9 @@ def prune_by_rmsd(
391414
Similarity occurs for structures with both RMSD < max_rmsd and
392415
maximum deviation < max_dev. max_dev by default is 2 * max_rmsd.
393416
"""
417+
if energies is None:
418+
energies = np.array([])
419+
394420
# set default max_dev if not provided
395421
max_dev = max_dev or 2 * max_rmsd
396422

@@ -400,6 +426,8 @@ def prune_by_rmsd(
400426
atoms=atoms,
401427
max_rmsd=max_rmsd,
402428
max_dev=max_dev,
429+
energies=energies,
430+
max_dE=max_dE,
403431
debugfunction=debugfunction,
404432
)
405433

@@ -413,6 +441,8 @@ def prune_by_rmsd_rot_corr(
413441
graph: Graph,
414442
max_rmsd: float = 0.25,
415443
max_dev: float | None = None,
444+
energies: Array1D_float | None = None,
445+
max_dE: float = 0.0,
416446
logfunction: Callable[[str], None] | None = None,
417447
debugfunction: Callable[[str], None] | None = None,
418448
) -> tuple[Array3D_float, Array1D_bool]:
@@ -535,10 +565,15 @@ def prune_by_rmsd_rot_corr(
535565
)
536566
logfunction("\n")
537567

568+
if energies is None:
569+
energies = np.array([])
570+
538571
# Initialize PrunerConfig
539572
prunerconfig = RMSDRotCorrPrunerConfig(
540573
structures=structures,
541574
atoms=atoms,
575+
energies=energies,
576+
max_dE=max_dE,
542577
graph=graph,
543578
torsions=torsions_ids,
544579
debugfunction=debugfunction,
@@ -561,19 +596,25 @@ def prune_by_moment_of_inertia(
561596
structures: Array3D_float,
562597
atoms: Array1D_str,
563598
max_deviation: float = 1e-2,
599+
energies: Array1D_float | None = None,
600+
max_dE: float = 0.0,
564601
debugfunction: Callable[[str], None] | None = None,
565602
) -> tuple[Array3D_float, Array1D_bool]:
566603
"""Remove duplicate structures using a moments of inertia-based metric.
567604
568605
Remove duplicate structures (enantiomeric or rotameric) based on the
569-
moments of inertia on the principal axes. If all three MOI
570-
deviate less than max_deviation percent from another structure,
571-
they are classified as rotamers or enantiomers and therefore only one
572-
of them is kept (i.e. max_deviation = 0.1 is 10% relative deviation).
606+
moment of inertia on the principal axes. If all three deviate less than
607+
max_deviation percent from another one, the structure is removed from
608+
the ensemble (i.e. max_deviation = 0.1 is 10% relative deviation).
573609
"""
610+
if energies is None:
611+
energies = np.array([])
612+
574613
# set up PrunerConfig dataclass
575614
prunerconfig = MOIPrunerConfig(
576615
structures=structures,
616+
energies=energies,
617+
max_dE=max_dE,
577618
debugfunction=debugfunction,
578619
max_dev=max_deviation,
579620
masses=np.array([elements.symbol(a).mass for a in atoms]),

prism_pruner/rmsd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
def rmsd_and_max(
1010
p: Array2D_float,
1111
q: Array2D_float,
12-
center: bool = False,
12+
center: bool = True,
1313
) -> tuple[float, float]:
1414
"""Return RMSD and max deviation.
1515

prism_pruner/torsion_module.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,10 @@ def rotationally_corrected_rmsd_and_max(
418418

419419
torsion_corrections = [0 for _ in torsions]
420420

421+
mask = (
422+
np.array([a != "H" for a in atoms]) if heavy_atoms_only else np.ones(len(atoms), dtype=bool)
423+
)
424+
421425
# Now rotate every dummy torsion by the appropriate increment until we minimize local RMSD
422426
for i, torsion in enumerate(torsions):
423427
best_rmsd = 1e10
@@ -432,7 +436,7 @@ def rotationally_corrected_rmsd_and_max(
432436
best_rmsd = locally_corrected_rmsd
433437
torsion_corrections[i] = angle
434438

435-
# it is faster to undo the rotation rather than working with a copy of coords
439+
# it is faster to undo the rotation rather than working with a copy of coords
436440
coord = rotate_dihedral(coord, torsion, -angle, indices_to_be_moved=[torsion[3]])
437441

438442
# now rotate that angle to the desired orientation before going to the next angle
@@ -442,18 +446,14 @@ def rotationally_corrected_rmsd_and_max(
442446
)
443447

444448
if debugfunction is not None:
445-
heavy_mask = np.array([a != "H" for a in atoms])
446-
global_rmsd = rmsd_and_max(ref[heavy_mask], coord[heavy_mask])[0]
449+
global_rmsd = rmsd_and_max(ref[mask], coord[mask])[0]
447450
debugfunction(
448451
f" Torsion {i + 1} - {torsion}: best θ = {torsion_corrections[i]}°, "
449452
+ f"4-atom RMSD: {best_rmsd:.3f} Å, global RMSD: {global_rmsd:.3f} Å"
450453
)
451454

452455
# we should have the optimal orientation on all torsions now:
453456
# calculate the RMSD
454-
mask = (
455-
np.array([a != "H" for a in atoms]) if heavy_atoms_only else np.ones(len(atoms), dtype=bool)
456-
)
457457
rmsd, maxdev = rmsd_and_max(ref[mask], coord[mask])
458458

459459
# since we could have segmented graphs, and therefore potentially only rotate

0 commit comments

Comments
 (0)