Skip to content

Commit 7ba0407

Browse files
committed
polished get_inertia_moments function
1 parent 94904b2 commit 7ba0407

File tree

3 files changed

+87
-31
lines changed

3 files changed

+87
-31
lines changed

examples/example_notebook.ipynb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
},
5353
{
5454
"cell_type": "code",
55-
"execution_count": 3,
55+
"execution_count": 4,
5656
"id": "31198a3f",
5757
"metadata": {},
5858
"outputs": [
@@ -61,12 +61,12 @@
6161
"output_type": "stream",
6262
"text": [
6363
"DEBUG: MOIPrunerConfig - k=50, rejected 449 (keeping 637/1086), in 0.1 s\n",
64-
"DEBUG: MOIPrunerConfig - k=20, rejected 109 (keeping 528/1086), in 0.1 s\n",
65-
"DEBUG: MOIPrunerConfig - k=10, rejected 27 (keeping 501/1086), in 0.1 s\n",
66-
"DEBUG: MOIPrunerConfig - k=5, rejected 28 (keeping 473/1086), in 0.4 s\n",
67-
"DEBUG: MOIPrunerConfig - k=2, rejected 38 (keeping 435/1086), in 0.5 s\n",
68-
"DEBUG: MOIPrunerConfig - k=1, rejected 10 (keeping 425/1086), in 0.6 s\n",
69-
"DEBUG: MOIPrunerConfig - keeping 425/1086 (1.9 s)\n",
64+
"DEBUG: MOIPrunerConfig - k=20, rejected 109 (keeping 528/1086), in 0.0 s\n",
65+
"DEBUG: MOIPrunerConfig - k=10, rejected 27 (keeping 501/1086), in 0.0 s\n",
66+
"DEBUG: MOIPrunerConfig - k=5, rejected 28 (keeping 473/1086), in 0.1 s\n",
67+
"DEBUG: MOIPrunerConfig - k=2, rejected 38 (keeping 435/1086), in 0.2 s\n",
68+
"DEBUG: MOIPrunerConfig - k=1, rejected 10 (keeping 425/1086), in 0.3 s\n",
69+
"DEBUG: MOIPrunerConfig - keeping 425/1086 (0.8 s)\n",
7070
"DEBUG: MOIPrunerConfig - Used cached data 105595/211707 times, 49.88% of total calls\n"
7171
]
7272
},
@@ -76,7 +76,7 @@
7676
"(425, 136, 3)"
7777
]
7878
},
79-
"execution_count": 3,
79+
"execution_count": 4,
8080
"metadata": {},
8181
"output_type": "execute_result"
8282
}

prism_pruner/algebra.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -118,27 +118,23 @@ def quaternion_to_rotation_matrix(quat: Array1D_float | Sequence[float]) -> Arra
118118

119119

120120
def get_inertia_moments(coords: Array3D_float, masses: Array1D_float) -> Array1D_float:
121-
"""
122-
Find the moments of inertia of the three principal axes.
121+
"""Compute the principal moments of inertia of a molecule.
123122
124-
:return: diagonal of the diagonalized inertia tensor, that is
125-
a shape (3,) array with the moments of inertia along the main axes.
126-
(I_x, I_y and largest I_z last)
123+
Returns a length-3 array [I_x, I_y, I_z], sorted ascending.
127124
"""
128-
# Center coordinates around the center of mass
129-
coords = coords - np.sum(coords * masses[:, np.newaxis], axis=0)
130-
131-
# Compute r^2 for each atom
132-
norms_squared = np.einsum("ni,ni->n", coords, coords)
125+
# Shift to center of mass
126+
com = np.sum(coords * masses[:, np.newaxis], axis=0) / np.sum(masses)
127+
coords = coords - com
133128

134-
# Build inertia tensor using einsum
135-
total = np.sum(masses * norms_squared)
136-
inertia_moment_matrix = total * np.eye(3) - np.einsum("n,ni,nj->ij", masses, coords, coords)
129+
# Compute inertia tensor
130+
norms_sq = np.einsum("ni,ni->n", coords, coords)
131+
total = np.sum(masses * norms_sq)
132+
I_matrix = total * np.eye(3) - np.einsum("n,ni,nj->ij", masses, coords, coords)
137133

138-
# diagonalize the matrix and return the diagonal
139-
inertia_moment_matrix = diagonalize(inertia_moment_matrix)
134+
# Principal moments via symmetric eigendecomposition
135+
moments, _ = np.linalg.eigh(I_matrix)
140136

141-
return np.diag(inertia_moment_matrix)
137+
return np.sort(moments)
142138

143139

144140
def diagonalize(a: Array2D_float) -> Array2D_float:

tests/cregen_comparison_notebook.ipynb

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
},
3030
{
3131
"cell_type": "code",
32-
"execution_count": null,
32+
"execution_count": 13,
3333
"id": "a52bffc3",
3434
"metadata": {},
3535
"outputs": [
@@ -50,8 +50,8 @@
5050
"DEBUG: RMSDPrunerConfig - keeping 61/64 (0.1 s)\n",
5151
"DEBUG: RMSDPrunerConfig - Used cached data 900/1732 times, 51.96% of total calls\n",
5252
"\n",
53-
"CPU times: user 344 ms, sys: 3.65 ms, total: 347 ms\n",
54-
"Wall time: 344 ms\n"
53+
"CPU times: user 303 ms, sys: 12.5 ms, total: 315 ms\n",
54+
"Wall time: 309 ms\n"
5555
]
5656
}
5757
],
@@ -80,7 +80,7 @@
8080
},
8181
{
8282
"cell_type": "code",
83-
"execution_count": null,
83+
"execution_count": 14,
8484
"id": "c2d35681",
8585
"metadata": {},
8686
"outputs": [
@@ -106,9 +106,9 @@
106106
" 13 - [189 193 197 218] : CCNO : 2-fold\n",
107107
"\n",
108108
"\n",
109-
"DEBUG: RMSDRotCorrPrunerConfig - k=2, rejected 14 (keeping 47/61), in 9.8 s\n",
110-
"DEBUG: RMSDRotCorrPrunerConfig - k=1, rejected 2 (keeping 45/61), in 1.2 s\n",
111-
"DEBUG: RMSDRotCorrPrunerConfig - keeping 45/61 (11.0 s)\n",
109+
"DEBUG: RMSDRotCorrPrunerConfig - k=2, rejected 14 (keeping 47/61), in 9.6 s\n",
110+
"DEBUG: RMSDRotCorrPrunerConfig - k=1, rejected 2 (keeping 45/61), in 1.1 s\n",
111+
"DEBUG: RMSDRotCorrPrunerConfig - keeping 45/61 (10.7 s)\n",
112112
"DEBUG: RMSDRotCorrPrunerConfig - Used cached data 531/1175 times, 45.19% of total calls\n"
113113
]
114114
}
@@ -127,6 +127,66 @@
127127
" logfunction=print,\n",
128128
")"
129129
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": 8,
134+
"id": "958ec59a",
135+
"metadata": {},
136+
"outputs": [
137+
{
138+
"data": {
139+
"text/plain": [
140+
"(10125, 220, 3)"
141+
]
142+
},
143+
"execution_count": 8,
144+
"metadata": {},
145+
"output_type": "execute_result"
146+
}
147+
],
148+
"source": [
149+
"import numpy as np\n",
150+
"\n",
151+
"morecoords = np.concatenate([ensemble.coords for _ in range(15)])\n",
152+
"morecoords.shape"
153+
]
154+
},
155+
{
156+
"cell_type": "code",
157+
"execution_count": 11,
158+
"id": "e0e31b99",
159+
"metadata": {},
160+
"outputs": [
161+
{
162+
"name": "stdout",
163+
"output_type": "stream",
164+
"text": [
165+
"DEBUG: MOIPrunerConfig - k=500, rejected 5990 (keeping 4135/10125), in 0.4 s\n",
166+
"DEBUG: MOIPrunerConfig - k=200, rejected 1349 (keeping 2786/10125), in 0.2 s\n",
167+
"DEBUG: MOIPrunerConfig - k=100, rejected 747 (keeping 2039/10125), in 0.1 s\n",
168+
"DEBUG: MOIPrunerConfig - k=50, rejected 450 (keeping 1589/10125), in 0.2 s\n",
169+
"DEBUG: MOIPrunerConfig - k=20, rejected 529 (keeping 1060/10125), in 0.3 s\n",
170+
"DEBUG: MOIPrunerConfig - k=10, rejected 404 (keeping 656/10125), in 0.2 s\n",
171+
"DEBUG: MOIPrunerConfig - k=5, rejected 313 (keeping 343/10125), in 0.1 s\n",
172+
"DEBUG: MOIPrunerConfig - k=2, rejected 200 (keeping 143/10125), in 0.1 s\n",
173+
"DEBUG: MOIPrunerConfig - k=1, rejected 71 (keeping 72/10125), in 0.1 s\n",
174+
"DEBUG: MOIPrunerConfig - keeping 72/10125 (1.6 s)\n",
175+
"DEBUG: MOIPrunerConfig - Used cached data 143688/314950 times, 45.62% of total calls\n",
176+
"CPU times: user 2.71 s, sys: 18.1 ms, total: 2.73 s\n",
177+
"Wall time: 2.72 s\n"
178+
]
179+
}
180+
],
181+
"source": [
182+
"%%time\n",
183+
"pruned, mask = prune_by_moment_of_inertia(\n",
184+
" morecoords,\n",
185+
" ensemble.atoms,\n",
186+
" max_deviation=0.01, # 1% difference\n",
187+
" debugfunction=print,\n",
188+
")"
189+
]
130190
}
131191
],
132192
"metadata": {

0 commit comments

Comments
 (0)