Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ ignore_missing_imports = True
exclude = doc/sphinx/source/


[mypy-MDAnalysis.analysis.rms]
ignore_errors = False

[mypy-MDAnalysis.analysis.*]
ignore_errors = True

Expand Down
70 changes: 48 additions & 22 deletions package/MDAnalysis/analysis/rms.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,24 @@
from ..exceptions import SelectionError
from ..lib.util import asiterable, iterable, get_weights

from typing import Optional, Union, List, Dict, Any, Sequence, Tuple, TYPE_CHECKING
from numpy.typing import NDArray

if TYPE_CHECKING:
from ..core.groups import AtomGroup
from ..core.universe import Universe


logger = logging.getLogger("MDAnalysis.analysis.rmsd")


def rmsd(a, b, weights=None, center=False, superposition=False):
def rmsd(
a: NDArray,
b: NDArray,
weights: Optional[NDArray] = None,
center: bool = False,
superposition: bool = False,
) -> float:
r"""Returns RMSD between two coordinate sets `a` and `b`.

`a` and `b` are arrays of the coordinates of N atoms of shape
Expand Down Expand Up @@ -282,7 +295,9 @@ def rmsd(a, b, weights=None, center=False, superposition=False):
return np.sqrt(np.sum((a - b) ** 2) / N)


def process_selection(select):
def process_selection(
select: Union[str, Tuple[str, str], Dict[str, str]]
) -> Dict[str, Any]:
"""Return a canonical selection dictionary.

Parameters
Expand Down Expand Up @@ -365,6 +380,14 @@ class RMSD(AnalysisBase):

_analysis_algorithm_is_parallelizable = True

atomgroup: Union["AtomGroup", "Universe"]
reference: Union["AtomGroup", "Universe"]
groupselections: List[Dict[str, Any]]
weights: Optional[Union[str, NDArray, List[Any]]]
weights_groupselections: Union[bool, List[Any]]
tol_mass: float
ref_frame: int

@classmethod
def get_supported_backends(cls):
return (
Expand All @@ -375,15 +398,15 @@ def get_supported_backends(cls):

def __init__(
self,
atomgroup,
reference=None,
select="all",
groupselections=None,
weights=None,
weights_groupselections=False,
tol_mass=0.1,
ref_frame=0,
**kwargs,
atomgroup: Union["AtomGroup", "Universe"],
reference: Optional[Union["AtomGroup", "Universe"]] = None,
select: Union[str, Dict[str, str], Tuple[str, str]] = "all",
groupselections: Optional[Sequence[Union[str, Dict[str, str], Tuple[str, str]]]] = None,
weights: Optional[Union[str, NDArray, List[Any]]] = None,
weights_groupselections: Union[bool, List[Any]] = False,
tol_mass: float = 0.1,
ref_frame: int = 0,
**kwargs: Any,
):
r"""Parameters
----------
Expand Down Expand Up @@ -665,7 +688,7 @@ def __init__(
+ " happens in selection %s" % selection["mobile"]
)

def _prepare(self):
def _prepare(self) -> None:
self._n_atoms = self.mobile_atoms.n_atoms
if not self.weights_groupselections:
if not iterable(
Expand All @@ -679,15 +702,17 @@ def _prepare(self):
self.groupselections
)

weights_gs = self.weights_groupselections
assert isinstance(weights_gs, list)
for igroup, (weights, atoms) in enumerate(
zip(self.weights_groupselections, self._groupselections_atoms)
zip(weights_gs, self._groupselections_atoms)
):
if str(weights) == "mass":
self.weights_groupselections[igroup] = atoms["mobile"].masses
weights_gs[igroup] = atoms["mobile"].masses
if weights is not None:
self.weights_groupselections[igroup] = np.asarray(
self.weights_groupselections[igroup], dtype=np.float64
) / np.mean(self.weights_groupselections[igroup])
weights_gs[igroup] = np.asarray(
weights_gs[igroup], dtype=np.float64
) / np.mean(weights_gs[igroup])
# add the array of weights to weights_select
self.weights_select = get_weights(self.mobile_atoms, self.weights)
self.weights_ref = get_weights(self.ref_atoms, self.weights)
Expand Down Expand Up @@ -746,7 +771,7 @@ def _prepare(self):
def _get_aggregator(self):
return ResultsGroup(lookup={"rmsd": ResultsGroup.ndarray_vstack})

def _single_frame(self):
def _single_frame(self) -> None:
mobile_com = self.mobile_atoms.center(self.weights_select).astype(
np.float64
)
Expand Down Expand Up @@ -787,6 +812,7 @@ def _single_frame(self):

# 2) calculate secondary RMSDs (without any further
# superposition)
assert isinstance(self.weights_groupselections, list)
for igroup, (refpos, atoms) in enumerate(
zip(
self._groupselections_ref_coords64,
Expand Down Expand Up @@ -846,7 +872,7 @@ class RMSF(AnalysisBase):
def get_supported_backends(cls):
return ("serial",)

def __init__(self, atomgroup, **kwargs):
def __init__(self, atomgroup: "AtomGroup", **kwargs: Any):
r"""Parameters
----------
atomgroup : AtomGroup
Expand Down Expand Up @@ -969,18 +995,18 @@ def __init__(self, atomgroup, **kwargs):
super(RMSF, self).__init__(atomgroup.universe.trajectory, **kwargs)
self.atomgroup = atomgroup

def _prepare(self):
def _prepare(self) -> None:
self.sumsquares = np.zeros((self.atomgroup.n_atoms, 3))
self.mean = self.sumsquares.copy()

def _single_frame(self):
def _single_frame(self) -> None:
k = self._frame_index
self.sumsquares += (k / (k + 1.0)) * (
self.atomgroup.positions - self.mean
) ** 2
self.mean = (k * self.mean + self.atomgroup.positions) / (k + 1)

def _conclude(self):
def _conclude(self) -> None:
k = self._frame_index
self.results.rmsf = np.sqrt(self.sumsquares.sum(axis=1) / (k + 1))

Expand Down
Loading