Skip to content

Commit f34e6ba

Browse files
talagayevmarinegor
andauthored
'MDAnalysis.analysis.diffusionmap' parallelization (#4745)
`MDAnalysis.analysis.diffusionmap` parallelization - fix #4679 - enable parallelization for `diffusionmap` with all backends - update documentation and tests with appropriate fixture --------- Co-authored-by: Egor Marin <[email protected]>
1 parent 5c11b50 commit f34e6ba

File tree

4 files changed

+87
-33
lines changed

4 files changed

+87
-33
lines changed

package/CHANGELOG

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The rules for this file:
1414

1515

1616
-------------------------------------------------------------------------------
17-
??/??/?? IAlibay, orbeckst, marinegor, tylerjereddy, ljwoods2
17+
??/??/?? IAlibay, orbeckst, marinegor, tylerjereddy, ljwoods2, talagayev
1818

1919
* 2.11.0
2020

@@ -26,6 +26,10 @@ Fixes
2626
* Fixes incorrect assignment of secondary structure to proline residues in
2727
DSSP by porting upstream PyDSSP 0.9.1 fix (Issue #4913)
2828

29+
Enhancements
30+
* Enables parallelization for analysis.diffusionmap.DistanceMatrix
31+
(Issue #4679, PR #4745)
32+
2933
Changes
3034

3135
Deprecations

package/MDAnalysis/analysis/diffusionmap.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@
143143
from MDAnalysis.core.universe import Universe
144144
from MDAnalysis.core.groups import AtomGroup, UpdatingAtomGroup
145145
from .rms import rmsd
146-
from .base import AnalysisBase
146+
from .base import AnalysisBase, ResultsGroup
147147

148148
logger = logging.getLogger("MDAnalysis.analysis.diffusionmap")
149149

@@ -234,8 +234,22 @@ class DistanceMatrix(AnalysisBase):
234234
.. versionchanged:: 2.8.0
235235
:class:`DistanceMatrix` is now correctly works with `frames=...`
236236
parameter (#4432) by iterating over `self._sliced_trajectory`
237+
.. versionchanged:: 2.11.0
238+
Enabled **parallel execution** with the ``multiprocessing`` and ``dask``
239+
backends; use the new method :meth:`get_supported_backends` to see all
240+
supported backends.
237241
"""
238242

243+
_analysis_algorithm_is_parallelizable = True
244+
245+
@classmethod
246+
def get_supported_backends(cls):
247+
return (
248+
"serial",
249+
"multiprocessing",
250+
"dask",
251+
)
252+
239253
def __init__(
240254
self,
241255
universe,
@@ -265,27 +279,16 @@ def __init__(
265279
self._calculated = False
266280

267281
def _prepare(self):
268-
self.results.dist_matrix = np.zeros((self.n_frames, self.n_frames))
282+
# Perpare for parallelization workers
283+
n_atoms = self.atoms.n_atoms
284+
n_dim = self.atoms.positions.shape[1]
285+
self.results._positions = np.zeros(
286+
(self.n_frames, n_atoms, n_dim), dtype=np.float64
287+
)
269288

270289
def _single_frame(self):
271-
iframe = self._frame_index
272-
i_ref = self.atoms.positions
273-
# diagonal entries need not be calculated due to metric(x,x) == 0 in
274-
# theory, _ts not updated properly. Possible savings by setting a
275-
# cutoff for significant decimal places to sparsify matrix
276-
for j, ts in enumerate(self._sliced_trajectory[iframe:]):
277-
self._ts = ts
278-
j_ref = self.atoms.positions
279-
dist = self._metric(i_ref, j_ref, weights=self._weights)
280-
self.results.dist_matrix[
281-
self._frame_index, j + self._frame_index
282-
] = (dist if dist > self._cutoff else 0)
283-
self.results.dist_matrix[
284-
j + self._frame_index, self._frame_index
285-
] = self.results.dist_matrix[
286-
self._frame_index, j + self._frame_index
287-
]
288-
self._ts = self._sliced_trajectory[iframe]
290+
# Store current frame positions
291+
self.results._positions[self._frame_index] = self.atoms.positions
289292

290293
@property
291294
def dist_matrix(self):
@@ -298,8 +301,38 @@ def dist_matrix(self):
298301
return self.results.dist_matrix
299302

300303
def _conclude(self):
304+
# Build the full pairwise distance matrix from stored positions
305+
# Calculate and store results
306+
pos = np.asarray(
307+
self.results._positions, dtype=np.float64
308+
) # (n_frames, n_atoms, n_dim)
309+
n = pos.shape[0]
310+
311+
D = np.zeros((n, n), dtype=np.float64)
312+
313+
metric = self._metric
314+
cutoff = self._cutoff
315+
weights = self._weights
316+
317+
for i in range(n):
318+
pi = pos[i]
319+
for j in range(i + 1, n):
320+
pj = pos[j]
321+
d = metric(pi, pj, weights=weights)
322+
if d > cutoff:
323+
D[i, j] = d
324+
D[j, i] = d
325+
326+
self.results.dist_matrix = D
301327
self._calculated = True
302328

329+
def _get_aggregator(self):
330+
return ResultsGroup(
331+
lookup={
332+
"_positions": ResultsGroup.ndarray_vstack, # Get positions
333+
}
334+
)
335+
303336

304337
class DiffusionMap(object):
305338
"""Non-linear dimension reduction method

testsuite/MDAnalysisTests/analysis/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from MDAnalysis.analysis.dihedrals import Dihedral, Ramachandran, Janin
1111
from MDAnalysis.analysis.bat import BAT
1212
from MDAnalysis.analysis.gnm import GNMAnalysis
13+
from MDAnalysis.analysis.diffusionmap import DistanceMatrix
1314
from MDAnalysis.analysis.dssp.dssp import DSSP
1415
from MDAnalysis.analysis.hydrogenbonds.hbond_analysis import (
1516
HydrogenBondAnalysis,
@@ -208,3 +209,11 @@ def client_InterRDF(request):
208209
@pytest.fixture(scope="module", params=params_for_cls(InterRDF_s))
209210
def client_InterRDF_s(request):
210211
return request.param
212+
213+
214+
# MDAnalysis.analysis.diffusionmap
215+
216+
217+
@pytest.fixture(scope="module", params=params_for_cls(DistanceMatrix))
218+
def client_DistanceMatrix(request):
219+
return request.param

testsuite/MDAnalysisTests/analysis/test_diffusionmap.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,13 @@ def test_eg(dist, dmap):
5252
# makes no sense to test values here, no physical meaning
5353

5454

55-
def test_dist_weights(u):
55+
def test_dist_weights(u, client_DistanceMatrix):
5656
backbone = u.select_atoms("backbone")
5757
weights_atoms = np.ones(len(backbone.atoms))
5858
dist = diffusionmap.DistanceMatrix(
5959
u, select="backbone", weights=weights_atoms
6060
)
61-
dist.run(step=3)
61+
dist.run(**client_DistanceMatrix, step=3)
6262
dmap = diffusionmap.DiffusionMap(dist)
6363
dmap.run()
6464
assert_array_almost_equal(dmap.eigenvalues, [1, 1, 1, 1], 4)
@@ -76,11 +76,11 @@ def test_dist_weights(u):
7676
)
7777

7878

79-
def test_dist_weights_frames(u):
79+
def test_dist_weights_frames(u, client_DistanceMatrix):
8080
backbone = u.select_atoms("backbone")
8181
weights_atoms = np.ones(len(backbone.atoms))
8282
dist = diffusionmap.DistanceMatrix(
83-
u, select="backbone", weights=weights_atoms
83+
u, **client_DistanceMatrix, select="backbone", weights=weights_atoms
8484
)
8585
frames = np.arange(len(u.trajectory))
8686
dist.run(frames=frames[::3])
@@ -101,19 +101,25 @@ def test_dist_weights_frames(u):
101101
)
102102

103103

104-
def test_distvalues_ag_universe(u):
105-
dist_universe = diffusionmap.DistanceMatrix(u, select="backbone").run()
104+
def test_distvalues_ag_universe(u, client_DistanceMatrix):
105+
dist_universe = diffusionmap.DistanceMatrix(u, select="backbone").run(
106+
**client_DistanceMatrix
107+
)
106108
ag = u.select_atoms("backbone")
107-
dist_ag = diffusionmap.DistanceMatrix(ag).run()
109+
dist_ag = diffusionmap.DistanceMatrix(ag).run(**client_DistanceMatrix)
108110
assert_allclose(
109111
dist_universe.results.dist_matrix, dist_ag.results.dist_matrix
110112
)
111113

112114

113-
def test_distvalues_ag_select(u):
114-
dist_universe = diffusionmap.DistanceMatrix(u, select="backbone").run()
115+
def test_distvalues_ag_select(u, client_DistanceMatrix):
116+
dist_universe = diffusionmap.DistanceMatrix(u, select="backbone").run(
117+
**client_DistanceMatrix
118+
)
115119
ag = u.select_atoms("protein")
116-
dist_ag = diffusionmap.DistanceMatrix(ag, select="backbone").run()
120+
dist_ag = diffusionmap.DistanceMatrix(ag, select="backbone").run(
121+
**client_DistanceMatrix
122+
)
117123
assert_allclose(
118124
dist_universe.results.dist_matrix, dist_ag.results.dist_matrix
119125
)
@@ -156,8 +162,10 @@ def test_not_universe_atomgroup_error(u):
156162
diffusionmap.DiffusionMap(trj_only)
157163

158164

159-
def test_DistanceMatrix_attr_warning(u):
160-
dist = diffusionmap.DistanceMatrix(u, select="backbone").run(step=3)
165+
def test_DistanceMatrix_attr_warning(u, client_DistanceMatrix):
166+
dist = diffusionmap.DistanceMatrix(u, select="backbone").run(
167+
**client_DistanceMatrix, step=3
168+
)
161169
wmsg = f"The `dist_matrix` attribute was deprecated in MDAnalysis 2.0.0"
162170
with pytest.warns(DeprecationWarning, match=wmsg):
163171
assert getattr(dist, "dist_matrix") is dist.results.dist_matrix

0 commit comments

Comments
 (0)