|
9 | 9 |
|
10 | 10 | import inspect |
11 | 11 | from collections import OrderedDict |
| 12 | +from functools import partial |
12 | 13 | from typing import Any, Dict, Optional, Tuple |
13 | 14 |
|
14 | 15 | import numpy as np |
@@ -40,6 +41,7 @@ class AbstractCollectiveVariable(openmm.Force): |
40 | 41 | """ |
41 | 42 |
|
42 | 43 | _unit: mmunit.Unit = mmunit.dimensionless |
| 44 | + _mass_unit: mmunit.Unit = mmunit.dalton * mmunit.nanometers**2 |
43 | 45 | _args: Dict[str, Any] = {} |
44 | 46 |
|
45 | 47 | def __getstate__(self) -> Dict[str, Any]: |
@@ -67,6 +69,7 @@ def _registerCV(self, unit: mmunit.Unit, *args: Any, **kwargs: Any) -> None: |
67 | 69 | """ |
68 | 70 | self.setName(self.__class__.__name__) |
69 | 71 | self.setUnit(unit) |
| 72 | + self._mass_unit = mmunit.dalton * (mmunit.nanometers / self.getUnit()) ** 2 |
70 | 73 | arguments, _ = self.getArguments() |
71 | 74 | self._args = dict(zip(arguments, args)) |
72 | 75 | self._args.update(kwargs) |
@@ -327,20 +330,25 @@ def getEffectiveMass( |
327 | 330 | 1 |
328 | 331 | >>> model.system.addForce(radius_of_gyration) |
329 | 332 | 6 |
330 | | - >>> platform =openmm.Platform.getPlatformByName('Reference') |
331 | | - >>> context =openmm.Context( |
| 333 | + >>> platform = openmm.Platform.getPlatformByName('Reference') |
| 334 | + >>> context = openmm.Context( |
332 | 335 | ... model.system,openmm.VerletIntegrator(0), platform |
333 | 336 | ... ) |
334 | 337 | >>> context.setPositions(model.positions) |
335 | 338 | >>> print(radius_of_gyration.getEffectiveMass(context, digits=6)) |
336 | 339 | 30.94693 Da |
337 | 340 | """ |
338 | 341 | state = self._getSingleForceState(context, getForces=True) |
339 | | - force_values = value_in_md_units(state.getForces(asNumpy=True)) |
340 | | - mass_values = [ |
341 | | - value_in_md_units(context.getSystem().getParticleMass(i)) |
342 | | - for i in range(context.getSystem().getNumParticles()) |
343 | | - ] |
344 | | - effective_mass = 1.0 / np.sum(np.sum(force_values**2, axis=1) / mass_values) |
345 | | - unit = mmunit.dalton * (mmunit.nanometers / self.getUnit()) ** 2 |
346 | | - return mmunit.Quantity(self._precisionRound(effective_mass, digits), unit) |
| 342 | + # pylint: disable=protected-access |
| 343 | + get_mass = partial(openmm._openmm.System_getParticleMass, context.getSystem()) |
| 344 | + force_vectors = state.getForces(asNumpy=True)._value |
| 345 | + # pylint: enable=protected-access |
| 346 | + squared_forces = np.sum(np.square(force_vectors), axis=1) |
| 347 | + nonzeros = np.nonzero(squared_forces)[0] |
| 348 | + if nonzeros.size == 0: |
| 349 | + return mmunit.Quantity(np.inf, self._mass_unit) |
| 350 | + mass_values = np.fromiter(map(get_mass, nonzeros), dtype=np.float64) |
| 351 | + effective_mass = 1.0 / np.sum(squared_forces[nonzeros] / mass_values) |
| 352 | + return mmunit.Quantity( |
| 353 | + self._precisionRound(effective_mass, digits), self._mass_unit |
| 354 | + ) |
0 commit comments