Skip to content

Commit 79dead3

Browse files
DrDomenicoMarson, RMeli, orbeckst
authored
Added parallelization to polymer.PersistenceLength (#5074)
* Fixes #4671
* Enables parallelization for `analysis.polymer.PersistenceLength`
* Access to `results.raw_bond_autocorr`
* update CHANGELOG

Co-authored-by: Rocco Meli <[email protected]>
Co-authored-by: Oliver Beckstein <[email protected]>
1 parent 8020a0b commit 79dead3

File tree

4 files changed

+61
-27
lines changed

4 files changed

+61
-27
lines changed

package/CHANGELOG

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ The rules for this file:
1616
-------------------------------------------------------------------------------
1717
??/??/?? IAlibay, orbeckst, BHM-Bob, TRY-ER, Abdulrahman-PROG, pbuslaev,
1818
yuxuanzhuang, yuyuan871111, tanishy7777, tulga-rdn, Gareth-elliott,
19-
hmacdope, tylerjereddy, cbouy, talagayev
19+
hmacdope, tylerjereddy, cbouy, talagayev, DrDomenicoMarson
2020

2121

2222
* 2.10.0
@@ -68,6 +68,8 @@ Enhancements
6868
(PR #5038)
6969
* Moved distopia checking function to common import location in
7070
MDAnalysisTest.util (PR #5038)
71+
* Enables parallelization for `analysis.polymer.PersistenceLength` (Issue #4671, PR #5074)
72+
7173

7274
Changes
7375
* Refactored the RDKit converter code to move the inferring code in a separate

package/MDAnalysis/analysis/polymer.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from .. import NoDataError
4141
from ..core.groups import requires, AtomGroup
4242
from ..lib.distances import calc_bonds
43-
from .base import AnalysisBase
43+
from .base import AnalysisBase, ResultsGroup
4444

4545
logger = logging.getLogger(__name__)
4646

@@ -236,8 +236,17 @@ class PersistenceLength(AnalysisBase):
236236
Former ``results`` are now stored as ``results.bond_autocorrelation``.
237237
:attr:`lb`, :attr:`lp`, :attr:`fit` are now stored in a
238238
:class:`MDAnalysis.analysis.base.Results` instance.
239+
.. versionchanged:: 2.10.0
240+
introduced :meth:`get_supported_backends` allowing for parallel
241+
execution on ``multiprocessing`` and ``dask`` backends.
239242
"""
240243

244+
_analysis_algorithm_is_parallelizable = True
245+
246+
@classmethod
247+
def get_supported_backends(cls):
248+
return ("serial", "multiprocessing", "dask")
249+
241250
def __init__(self, atomgroups, **kwargs):
242251
super(PersistenceLength, self).__init__(
243252
atomgroups[0].universe.trajectory, **kwargs
@@ -249,15 +258,18 @@ def __init__(self, atomgroups, **kwargs):
249258
chainlength = len(atomgroups[0])
250259
if not all(l == chainlength for l in lens):
251260
raise ValueError("Not all AtomGroups were the same size")
261+
self.chainlength = chainlength
252262

253-
self._results = np.zeros(chainlength - 1, dtype=np.float32)
263+
def _prepare(self):
264+
self.results.raw_bond_autocorr = np.zeros(
265+
self.chainlength - 1, dtype=np.float32
266+
)
254267

255268
def _single_frame(self):
256269
# could optimise this by writing a "self dot array"
257270
# we're only using the upper triangle of np.inner
258271
# function would accept a bunch of coordinates and spit out the
259272
# decorrel for that
260-
n = len(self._atomgroups[0])
261273

262274
for chain in self._atomgroups:
263275
# Vector from each atom to next
@@ -266,8 +278,17 @@ def _single_frame(self):
266278
vecs /= np.sqrt((vecs * vecs).sum(axis=1))[:, None]
267279

268280
inner_pr = np.inner(vecs, vecs)
269-
for i in range(n - 1):
270-
self._results[: (n - 1) - i] += inner_pr[i, i:]
281+
for i in range(self.chainlength - 1):
282+
self.results.raw_bond_autocorr[
283+
: (self.chainlength - 1) - i
284+
] += inner_pr[i, i:]
285+
286+
def _get_aggregator(self):
287+
return ResultsGroup(
288+
lookup={
289+
"raw_bond_autocorr": ResultsGroup.ndarray_sum,
290+
}
291+
)
271292

272293
@property
273294
def lb(self):
@@ -300,14 +321,12 @@ def fit(self):
300321
return self.results.fit
301322

302323
def _conclude(self):
303-
n = len(self._atomgroups[0])
304-
305-
norm = np.linspace(n - 1, 1, n - 1)
306-
norm *= len(self._atomgroups) * self.n_frames
307-
308-
self.results.bond_autocorrelation = self._results / norm
324+
norm = np.linspace(self.chainlength - 1, 1, self.chainlength - 1)
325+
norm *= len(self._atomgroups) * self._trajectory.n_frames
326+
self.results.bond_autocorrelation = (
327+
self.results.raw_bond_autocorr / norm
328+
)
309329
self._calc_bond_length()
310-
311330
self._perform_fit()
312331

313332
def _calc_bond_length(self):
@@ -350,7 +369,7 @@ def plot(self, ax=None):
350369
import matplotlib.pyplot as plt
351370

352371
if ax is None:
353-
fig, ax = plt.subplots()
372+
_, ax = plt.subplots()
354373
ax.plot(
355374
self.results.x,
356375
self.results.bond_autocorrelation,

testsuite/MDAnalysisTests/analysis/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from MDAnalysis.analysis.contacts import Contacts
1919
from MDAnalysis.analysis.density import DensityAnalysis
2020
from MDAnalysis.analysis.lineardensity import LinearDensity
21+
from MDAnalysis.analysis.polymer import PersistenceLength
2122
from MDAnalysis.lib.util import is_installed
2223

2324

@@ -185,3 +186,11 @@ def client_DensityAnalysis(request):
185186
@pytest.fixture(scope="module", params=params_for_cls(LinearDensity))
186187
def client_LinearDensity(request):
187188
return request.param
189+
190+
191+
# MDAnalysis.analysis.polymer
192+
193+
194+
@pytest.fixture(scope="module", params=params_for_cls(PersistenceLength))
195+
def client_PersistenceLength(request):
196+
return request.param

testsuite/MDAnalysisTests/analysis/test_persistencelength.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,33 @@
3636
from MDAnalysisTests.datafiles import Plength, TRZ_psf, TRZ
3737

3838

39+
def test_class_is_parallelizable():
40+
assert polymer.PersistenceLength._analysis_algorithm_is_parallelizable
41+
42+
43+
def test_supported_backends():
44+
assert polymer.PersistenceLength.get_supported_backends() == (
45+
"serial",
46+
"multiprocessing",
47+
"dask",
48+
)
49+
50+
3951
class TestPersistenceLength(object):
4052
@staticmethod
4153
@pytest.fixture()
4254
def u():
4355
return mda.Universe(Plength)
4456

45-
@staticmethod
4657
@pytest.fixture()
47-
def p(u):
58+
def p(self, u):
4859
ags = [r.atoms.select_atoms("name C* N*") for r in u.residues]
49-
5060
p = polymer.PersistenceLength(ags)
5161
return p
5262

53-
@staticmethod
5463
@pytest.fixture()
55-
def p_run(p):
56-
return p.run()
64+
def p_run(self, p, client_PersistenceLength):
65+
return p.run(**client_PersistenceLength)
5766

5867
def test_ag_ValueError(self, u):
5968
ags = [u.atoms[:10], u.atoms[10:110]]
@@ -81,15 +90,11 @@ def test_raise_NoDataError(self, p):
8190
def test_plot_ax_return(self, p_run):
8291
"""Ensure that a matplotlib axis object is
8392
returned when plot() is called."""
84-
actual = p_run.plot()
85-
expected = matplotlib.axes.Axes
86-
assert isinstance(actual, expected)
93+
assert isinstance(p_run.plot(), matplotlib.axes.Axes)
8794

8895
def test_plot_with_ax(self, p_run):
8996
fig, ax = plt.subplots()
90-
9197
ax2 = p_run.plot(ax=ax)
92-
9398
assert ax2 is ax
9499

95100
def test_current_axes(self, p_run):
@@ -98,8 +103,7 @@ def test_current_axes(self, p_run):
98103
assert ax2 is not ax
99104

100105
@pytest.mark.parametrize("attr", ("lb", "lp", "fit"))
101-
def test(self, p, attr):
102-
p_run = p.run(step=3)
106+
def test(self, p_run, attr):
103107
wmsg = f"The `{attr}` attribute was deprecated in MDAnalysis 2.0.0"
104108
with pytest.warns(DeprecationWarning, match=wmsg):
105109
getattr(p_run, attr) is p_run.results[attr]

0 commit comments

Comments
 (0)