Skip to content

Commit 8cb2f08

Browse files
author
Charlles Abreu
authored
Enhances effective mass calculation by ignoring null forces (#52)
* Remove null forces from effective mass calculation * Avoid division by zero
1 parent 8f19d3e commit 8cb2f08

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

cvpack/cvpack.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import inspect
1111
from collections import OrderedDict
12+
from functools import partial
1213
from typing import Any, Dict, Optional, Tuple
1314

1415
import numpy as np
@@ -40,6 +41,7 @@ class AbstractCollectiveVariable(openmm.Force):
4041
"""
4142

4243
_unit: mmunit.Unit = mmunit.dimensionless
44+
_mass_unit: mmunit.Unit = mmunit.dalton * mmunit.nanometers**2
4345
_args: Dict[str, Any] = {}
4446

4547
def __getstate__(self) -> Dict[str, Any]:
@@ -67,6 +69,7 @@ def _registerCV(self, unit: mmunit.Unit, *args: Any, **kwargs: Any) -> None:
6769
"""
6870
self.setName(self.__class__.__name__)
6971
self.setUnit(unit)
72+
self._mass_unit = mmunit.dalton * (mmunit.nanometers / self.getUnit()) ** 2
7073
arguments, _ = self.getArguments()
7174
self._args = dict(zip(arguments, args))
7275
self._args.update(kwargs)
@@ -327,20 +330,25 @@ def getEffectiveMass(
327330
1
328331
>>> model.system.addForce(radius_of_gyration)
329332
6
330-
>>> platform =openmm.Platform.getPlatformByName('Reference')
331-
>>> context =openmm.Context(
333+
>>> platform = openmm.Platform.getPlatformByName('Reference')
334+
>>> context = openmm.Context(
332335
... model.system,openmm.VerletIntegrator(0), platform
333336
... )
334337
>>> context.setPositions(model.positions)
335338
>>> print(radius_of_gyration.getEffectiveMass(context, digits=6))
336339
30.94693 Da
337340
"""
338341
state = self._getSingleForceState(context, getForces=True)
339-
force_values = value_in_md_units(state.getForces(asNumpy=True))
340-
mass_values = [
341-
value_in_md_units(context.getSystem().getParticleMass(i))
342-
for i in range(context.getSystem().getNumParticles())
343-
]
344-
effective_mass = 1.0 / np.sum(np.sum(force_values**2, axis=1) / mass_values)
345-
unit = mmunit.dalton * (mmunit.nanometers / self.getUnit()) ** 2
346-
return mmunit.Quantity(self._precisionRound(effective_mass, digits), unit)
342+
# pylint: disable=protected-access
343+
get_mass = partial(openmm._openmm.System_getParticleMass, context.getSystem())
344+
force_vectors = state.getForces(asNumpy=True)._value
345+
# pylint: enable=protected-access
346+
squared_forces = np.sum(np.square(force_vectors), axis=1)
347+
nonzeros = np.nonzero(squared_forces)[0]
348+
if nonzeros.size == 0:
349+
return mmunit.Quantity(np.inf, self._mass_unit)
350+
mass_values = np.fromiter(map(get_mass, nonzeros), dtype=np.float64)
351+
effective_mass = 1.0 / np.sum(squared_forces[nonzeros] / mass_values)
352+
return mmunit.Quantity(
353+
self._precisionRound(effective_mass, digits), self._mass_unit
354+
)

0 commit comments

Comments
 (0)