diff --git a/.github/actions/setup-deps/action.yaml b/.github/actions/setup-deps/action.yaml index d0c6b3dd4f3..7523214b8fa 100644 --- a/.github/actions/setup-deps/action.yaml +++ b/.github/actions/setup-deps/action.yaml @@ -64,6 +64,8 @@ inputs: default: 'h5py>=2.10' hole2: default: 'hole2' + imdclient: + default: 'imdclient>=0.2.2' joblib: default: 'joblib>=0.12' netcdf4: @@ -138,6 +140,7 @@ runs: ${{ inputs.gsd }} ${{ inputs.h5py }} ${{ inputs.hole2 }} + ${{ inputs.imdclient }} ${{ inputs.joblib }} ${{ inputs.netcdf4 }} ${{ inputs.networkx }} diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 68b8f470fa2..e3a65315c2b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -113,6 +113,8 @@ jobs: pytng>=0.2.3 rdkit>=2024.03.4 tidynamics>=1.0.0 + imdclient>=0.2.2 + # remove from azure to avoid test hanging #4707 # "gsd>3.0.0" displayName: 'Install additional dependencies for 64-bit tests' diff --git a/maintainer/conda/environment.yml b/maintainer/conda/environment.yml index 852044fc79a..7b69149302c 100644 --- a/maintainer/conda/environment.yml +++ b/maintainer/conda/environment.yml @@ -30,6 +30,7 @@ dependencies: - sphinxcontrib-bibtex - mdaencore - waterdynamics + - imdclient>=0.2.2 - pip: - mdahole2 - pathsimanalysis diff --git a/package/AUTHORS b/package/AUTHORS index 61eadbaf75d..7426186154e 100644 --- a/package/AUTHORS +++ b/package/AUTHORS @@ -260,6 +260,7 @@ Chronological list of authors - Gareth Elliott - Marc Schuh - Sirsha Ganguly + - Amruthesh Thirumalaiswamy External code ------------- diff --git a/package/CHANGELOG b/package/CHANGELOG index 75dcd5f534d..ebaa2902bd6 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -16,7 +16,7 @@ The rules for this file: ------------------------------------------------------------------------------- ??/??/?? IAlibay, orbeckst, BHM-Bob, TRY-ER, Abdulrahman-PROG, pbuslaev, yuxuanzhuang, yuyuan871111, tanishy7777, tulga-rdn, Gareth-elliott, - hmacdope, tylerjereddy, cbouy, talagayev, DrDomenicoMarson + hmacdope, tylerjereddy, cbouy, talagayev, DrDomenicoMarson, amruthesht * 2.10.0 @@ -42,6 +42,11 @@ Fixes directly passing them. (Issue #3520, PR #5006) Enhancements + * Added support for reading and processing streamed data in `coordinates.base` + with new `StreamFrameIteratorSliced` and `StreamReaderBase` (Issue #4827, PR #4923) + * New coordinate reader: Added `IMDReader` for reading real-time streamed + molecular dynamics simulation data using the IMDv3 protocol - requires + `imdclient` package (Issue #4827, PR #4923) * Added capability to calculate MSD from frames with irregular (non-linear) time spacing in analysis.msd.EinsteinMSD with keyword argument `non_linear=True` (Issue #5028, PR #5066) @@ -70,7 +75,7 @@ Enhancements so that it gets passed through from the calling functions and classes (PR #5038) * Moved distopia checking function to common import location in - MDAnalysisTest.util (PR #5038) + MDAnalysisTest.util (PR #5038) * Enables parallelization for `analysis.polymer.PersistenceLength` (Issue #4671, PR #5074) diff --git a/package/MDAnalysis/coordinates/IMD.py b/package/MDAnalysis/coordinates/IMD.py new file mode 100644 index 00000000000..4945d43adfc --- /dev/null +++ b/package/MDAnalysis/coordinates/IMD.py @@ -0,0 +1,334 @@ +""" +IMDReader --- :mod:`MDAnalysis.coordinates.IMD` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This module provides support for reading molecular dynamics simulation data via the +`Interactive Molecular Dynamics (IMD) protocol v3 `_. +The IMD protocol allows two-way communicating molecular simulation data through a socket. +Via IMD, a simulation engine sends data to a receiver (in this case, the IMDClient) and the receiver can send forces and specific control +requests (such as pausing, resuming, or terminating the simulation) back to the simulation engine. + +.. note:: + This reader only supports IMDv3, which is implemented in GROMACS, LAMMPS, and NAMD at varying + stages of development. See the `imdclient simulation engine docs`_ for more. + While IMDv2 is widely available in simulation engines, it was designed primarily for visualization + and gaps are allowed in the stream (i.e., an inconsistent number of integrator time steps between transmitted coordinate arrays is allowed) + +The :class:`IMDReader` connects to a simulation via a socket and receives coordinate, +velocity, force, and energy data as the simulation progresses. This allows for real-time +monitoring and analysis of ongoing simulations. It uses the `imdclient package `_ +(dependency) to implement the IMDv3 protocol and manage the socket connection and data parsing. + +.. seealso:: + :class:`IMDReader` + Technical details and parameter options for the reader class + + `imdclient documentation `_ + Complete documentation for the IMDClient package + + `IMDClient GitHub repository `_ + Source code and development resources + +.. _`imdclient simulation engine docs`: https://imdclient.readthedocs.io/en/latest/usage.html + +Usage Example +------------- + +As an example of reading a stream, after configuring GROMACS to run a simulation with IMDv3 enabled +(see the `imdclient simulation engine docs`_ for +up-to-date resources on configuring each simulation engine), use the following commands: + +.. code-block:: bash + + gmx mdrun -v -nt 4 -imdwait -imdport 8889 + +The :class:`~MDAnalysis.coordinates.IMD.IMDReader` can then connect to the running simulation and stream data in real time: + +.. code-block:: python + + import MDAnalysis as mda + u = mda.Universe("topol.tpr", "imd://localhost:8889", buffer_size=10*1024*1024) + + print(" time [ position ] [ velocity ] [ force ] [ box ]") + sel = u.select_atoms("all") # Select all atoms; adjust selection as needed + for ts in u.trajectory: + print(f'{ts.time:8.3f} {sel[0].position} {sel[0].velocity} {sel[0].force} {u.dimensions[0:3]}') + +.. important:: + **Jupyter Notebook Users**: When using IMDReader in Jupyter notebooks, be aware that + **kernel restarts will not gracefully close active IMD connections**. This can leave + socket connections open, potentially preventing new connections to the same stream. + + Always use ``try/except/finally`` blocks to ensure proper cleanup: + + .. code-block:: python + + import MDAnalysis as mda + + try: + u = mda.Universe("topol.tpr", "imd://localhost:8889") + except Exception as e: + print(f"Error during connection: {e}") + else: + try: + # Your analysis code here + for ts in u.trajectory: + # Process each frame + pass + finally: + # Ensure connection is closed + u.trajectory.close() + + Always explicitly call ``u.trajectory.close()`` when finished with analysis to + ensure connection is closed properly. + +Important Limitations +--------------------- + +.. warning:: + The IMDReader has some important limitations that are inherent in streaming data. + +Since IMD streams data in real-time from a running simulation, it has fundamental +constraints that differ from traditional trajectory readers: + +* **No random access**: Cannot jump to arbitrary frame numbers or seek backwards +* **Forward-only access**: You can only move forward through frames as they arrive +* **No trajectory length**: The total number of frames is unknown until the simulation ends +* **Single-use iteration**: Cannot restart or rewind once the stream has been consumed +* **No independent copies**: Cannot create separate reader instances for the same stream +* **No stream restart**: Cannot reconnect or reopen once the connection is closed +* **No bulk operations**: Cannot extract all data at once using timeseries methods +* **Limited multiprocessing**: Cannot split reader across processes for parallel analysis +* **Single client connection**: Only one reader can connect to an IMD stream at a time +* **No trajectory Writing**: Complimentary IMD Writer class is not available for streaming data + +.. seealso:: + See :class:`~MDAnalysis.coordinates.base.StreamReaderBase` for technical details. + +Multiple Client Connections +--------------------------- + +The ability to establish multiple simultaneous connections to the same IMD port is +**MD engine implementation dependent**. Some simulation engines may allow multiple +clients to connect concurrently, while others may reject or fail additional connection +attempts. + +See the `imdclient simulation engine docs`_ for further details. + +.. important:: + Even when multiple connections are supported by the simulation engine, each connection + receives its own independent data stream. These streams may contain different data + depending on the simulation engine's configuration, so multiple connections should + not be assumed to provide identical data streams. + +Classes +------- + +.. autoclass:: IMDReader + :members: + :inherited-members: + +""" + +import numpy as np +import logging +import warnings + +from MDAnalysis.coordinates import core +from MDAnalysis.lib.util import store_init_arguments +from MDAnalysis.coordinates.base import StreamReaderBase + + +from packaging.version import Version + +MIN_IMDCLIENT_VERSION = Version("0.2.2") + +try: + import imdclient + from imdclient.IMDClient import IMDClient + from imdclient.utils import parse_host_port +except ImportError: + HAS_IMDCLIENT = False + imdclient_version = Version("0.0.0") + + # Allow building documentation without imdclient + import types + + class MockIMDClient: + pass + + imdclient = types.ModuleType("imdclient") + imdclient.IMDClient = MockIMDClient + imdclient.__version__ = "0.0.0" + +else: + HAS_IMDCLIENT = True + imdclient_version = Version(imdclient.__version__) + + # Check for compatibility: currently needs to be >=0.2.2 + if imdclient_version < MIN_IMDCLIENT_VERSION: + warnings.warn( + f"imdclient version {imdclient_version} is too old; " + f"need at least {MIN_IMDCLIENT_VERSION}, Your installed version of " + "imdclient will NOT be used.", + category=RuntimeWarning, + ) + HAS_IMDCLIENT = False + +logger = logging.getLogger("MDAnalysis.coordinates.IMDReader") + + +class IMDReader(StreamReaderBase): + """ + Coordinate reader implementing the IMDv3 protocol for streaming simulation data. + + This class handles the technical aspects of connecting to IMD-enabled simulation + engines and processing the incoming data stream. For usage examples and protocol + overview, see the module documentation above. + + The reader manages socket connections, data buffering, and frame parsing according + to the IMDv3 specification. It automatically handles different data packet types + (coordinates, velocities, forces, energies, timing) and populates MDAnalysis + timestep objects accordingly. + + Parameters + ---------- + filename : a string of the form "imd://host:port" where host is the hostname + or IP address of the listening simulation engine's IMD server and port + is the port number. + n_atoms : int (optional) + number of atoms in the system. defaults to number of atoms + in the topology. Don't set this unless you know what you're doing. + buffer_size: int (optional) default=10*(1024**2) + number of bytes of memory to allocate to the :class:`~imdclient.IMDClient.IMDClient`'s + internal buffer. Defaults to 10 megabytes. Larger buffers can improve + performance for analyses with periodic heavy computation. + **kwargs : dict (optional) + keyword arguments passed to the constructed :class:`~imdclient.IMDClient.IMDClient` + + Notes + ----- + The IMDReader provides access to additional simulation data through the timestep's + `data` attribute (`ts.data`). The following keys may be available depending on + what the simulation engine transmits: + + * `dt` : float + Time step size in picoseconds (from the `IMD_TIME`_ packet of the IMDv3 protocol) + * `step` : int + Current simulation step number (from the `IMD_TIME`_ packet of the IMDv3 protocol) + * Energy terms : float + Various energy components (e.g., 'potential', 'kinetic', 'total', etc.) + from the `IMD_ENERGIES`_ packet of the IMDv3 protocol. + + .. _IMD_TIME: https://imdclient.readthedocs.io/en/latest/protocol_v3.html#time + .. _IMD_ENERGIES: https://imdclient.readthedocs.io/en/latest/protocol_v3.html#energies + + .. note:: + For important limitations inherent to streaming data, see the module documentation above + and :class:`~MDAnalysis.coordinates.base.StreamReaderBase` for more technical details. + + .. versionadded:: 2.10.0 + """ + + format = "IMD" + + @store_init_arguments + def __init__( + self, + filename, + n_atoms=None, + buffer_size=10 * (1024**2), + **kwargs, + ): + if not HAS_IMDCLIENT: + raise ImportError( + "IMDReader requires the imdclient package. " + "Please install it with 'pip install imdclient'." + ) + + super(IMDReader, self).__init__(filename, **kwargs) + + self._imdclient = None + logger.debug("IMDReader initializing") + + if n_atoms is None: + raise ValueError("IMDReader: n_atoms must be specified") + self.n_atoms = n_atoms + + try: + host, port = parse_host_port(filename) + except ValueError as e: + raise ValueError(f"IMDReader: Invalid IMD URL '{filename}': {e}") + + # This starts the simulation + self._imdclient = IMDClient( + host, port, n_atoms, buffer_size=buffer_size, **kwargs + ) + + imdsinfo = self._imdclient.get_imdsessioninfo() + if imdsinfo.version != 3: + raise ValueError( + f"IMDReader: Detected IMD version v{imdsinfo.version}, " + + "but IMDReader is only compatible with v3" + ) + + self.ts = self._Timestep( + self.n_atoms, + positions=imdsinfo.positions, + velocities=imdsinfo.velocities, + forces=imdsinfo.forces, + **self._ts_kwargs, + ) + + try: + self._read_next_timestep() + except EOFError as e: + raise RuntimeError(f"IMDReader: Read error: {e}") from e + + def _read_frame(self, frame): + + imdf = self._imdclient.get_imdframe() + + self._frame = frame + self._load_imdframe_into_ts(imdf) + + logger.debug("IMDReader: Loaded frame %d", self._frame) + return self.ts + + def _load_imdframe_into_ts(self, imdf): + self.ts.frame = self._frame + if imdf.time is not None: + self.ts.time = imdf.time + self.ts.data["dt"] = imdf.dt + self.ts.data["step"] = imdf.step + if imdf.energies is not None: + self.ts.data.update( + {k: v for k, v in imdf.energies.items() if k != "step"} + ) + if imdf.box is not None: + self.ts.dimensions = core.triclinic_box(*imdf.box) + if imdf.positions is not None: + # must call copy because reference is expected to reset + # see 'test_frame_collect_all_same' in MDAnalysisTests.coordinates.base + np.copyto(self.ts.positions, imdf.positions) + if imdf.velocities is not None: + np.copyto(self.ts.velocities, imdf.velocities) + if imdf.forces is not None: + np.copyto(self.ts.forces, imdf.forces) + + @staticmethod + def _format_hint(thing): + if not isinstance(thing, str): + return False + # a weaker check for type hint + if thing.startswith("imd://"): + return True + else: + return False + + def close(self): + """Gracefully shut down the reader. Stops the producer thread.""" + logger.debug("IMDReader close() called") + if self._imdclient is not None: + self._imdclient.stop() + logger.debug("IMDReader shut down gracefully.") diff --git a/package/MDAnalysis/coordinates/__init__.py b/package/MDAnalysis/coordinates/__init__.py index a340711a414..f81c4915514 100644 --- a/package/MDAnalysis/coordinates/__init__.py +++ b/package/MDAnalysis/coordinates/__init__.py @@ -53,7 +53,7 @@ class that defines a common :ref:`Trajectory API` and allows other code to :class:`~MDAnalysis.coordinates.base.ProtoReader` object; all Readers are accessible through this entry point in the same manner ("`duck typing`_"). -There are three types of base Reader which act as starting points for each +There are four types of base Reader which act as starting points for each specific format. These are: :class:`~MDAnalysis.coordinates.base.ReaderBase` @@ -66,6 +66,12 @@ class that defines a common :ref:`Trajectory API` and allows other code to frame of information. This is used with formats such as GRO and CRD +:class:`~MDAnalysis.coordinates.base.StreamReaderBase` + A specialized Reader for continuous data streams such as live + simulation feeds. Unlike standard readers, streaming readers cannot + randomly access frames, rewind, or determine total length. This is + used for real-time trajectory data from simulations via IMD connections. + :class:`~MDAnalysis.coordinates.chain.ChainReader` An advanced Reader designed to read a sequence of files, to provide iteration over all the frames in each file seamlessly. @@ -277,6 +283,11 @@ class can choose an appropriate reader automatically. | library | | | file formats`_ and | | | | | :mod:`MDAnalysis.coordinates.chemfiles` | +---------------+-----------+-------+------------------------------------------------------+ + | IMD | imd:// | r | Receive simulation trajectory data using interactive | + | | : | | molecular dynamics version 3 (IMDv3) by configuring | + | | | | a socket address to a NAMD, GROMACS, or LAMMPS | + | | | | simulation. :mod:`MDAnalysis.coordinates.IMD` | + +---------------+-----------+-------+------------------------------------------------------+ .. [#a] This format can also be used to provide basic *topology* information (i.e. the list of atoms); it is possible to create a @@ -773,6 +784,7 @@ class can choose an appropriate reader automatically. from . import DMS from . import GMS from . import GRO +from . import IMD from . import INPCRD from . import LAMMPS from . import MOL2 diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index e44c502ef83..d2bfbe63d2c 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -50,6 +50,7 @@ .. autoclass:: FrameIteratorIndices +.. autoclass:: StreamFrameIteratorSliced .. _ReadersBase: @@ -87,8 +88,9 @@ .. autoclass:: ProtoReader :members: - - +.. autoclass:: StreamReaderBase + :members: + .. _WritersBase: Writers @@ -1844,3 +1846,468 @@ def __repr__(self): def convert(self, obj): raise NotImplementedError + +class StreamReaderBase(ReaderBase): + """Base class for readers that read a continuous stream of data. + + This class is designed for readers that process continuous data streams, + such as live feeds from simulations. Unlike traditional trajectory readers + that can randomly access frames, streaming readers have fundamental constraints: + + - **No random access**: Cannot seek to arbitrary frames (no ``traj[5]``) + - **Forward-only**: Can only iterate sequentially through frames + - **No length**: Total number of frames is unknown until stream ends + - **No rewinding**: Cannot restart or rewind the stream + - **No copying**: Cannot create independent copies of the reader + - **No reopening**: Cannot restart iteration once stream is consumed + - **No timeseries**: Cannot use ``timeseries()`` or bulk data extraction + - **No writers**: Cannot create ``Writer()`` or ``OtherWriter()`` instances + - **No pickling**: Cannot serialize reader instances (limits multiprocessing) + - **No StreamWriterBase**: No complementary Writer class available for streaming data + + + The reader raises :exc:`RuntimeError` for operations that require random + access or rewinding, including ``rewind()``, ``copy()``, ``timeseries()``, + ``Writer()``, ``OtherWriter()``, and ``len()``. Only slice notation is supported for iteration. + + Parameters + ---------- + filename : str or file-like + Source of the streaming data + convert_units : bool, optional + Whether to convert units from native to MDAnalysis units (default: True) + **kwargs + Additional keyword arguments passed to the parent ReaderBase + + See Also + -------- + StreamFrameIteratorSliced : Iterator for stepped streaming access + ReaderBase : Base class for standard trajectory readers + + + .. versionadded:: 2.10.0 + """ + + def __init__(self, filename, convert_units=True, **kwargs): + super(StreamReaderBase, self).__init__( + filename, convert_units=convert_units, **kwargs + ) + self._init_scope = True + self._reopen_called = False + self._first_ts = None + self._frame = -1 + + def _read_next_timestep(self): + # No rewinding- to both load the first frame after __init__ + # and access it again during iteration, we need to store first ts in mem + if not self._init_scope and self._frame == -1: + self._frame += 1 + # can't simply return the same ts again- transformations would be applied twice + # instead, return the pre-transformed copy + return self._first_ts + + ts = self._read_frame(self._frame + 1) + + if self._init_scope: + self._first_ts = self.ts.copy() + self._init_scope = False + + return ts + + @property + def n_frames(self): + """Changes as stream is processed unlike other readers""" + raise RuntimeError( + "{}: n_frames is unknown".format(self.__class__.__name__) + ) + + def __len__(self): + raise RuntimeError( + "{} has unknown length".format(self.__class__.__name__) + ) + + def next(self): + """Advance to the next timestep in the streaming trajectory. + + Streaming readers process frames sequentially and cannot rewind + once iteration completes. Use ``for ts in trajectory`` for iteration. + + Returns + ------- + Timestep + The next timestep in the stream + + Raises + ------ + StopIteration + When the stream ends or no more frames are available + """ + try: + ts = self._read_next_timestep() + except (EOFError, IOError): + # Don't rewind here like we normally would + raise StopIteration from None + else: + for auxname, reader in self._auxs.items(): + ts = self._auxs[auxname].update_ts(ts) + + ts = self._apply_transformations(ts) + + return ts + + def rewind(self): + """Rewinding is not supported for streaming trajectories. + + Streaming readers process data continuously from streams + and cannot restart or go backward in the stream once consumed. + + Raises + ------ + RuntimeError + Always raised, as rewinding is not supported for streaming trajectories + """ + raise RuntimeError( + "{}: Stream-based readers can't be rewound".format( + self.__class__.__name__ + ) + ) + + # Incompatible methods + def copy(self): + """Reader copying is not supported for streaming trajectories. + + Streaming readers maintain internal state and connection resources + that cannot be duplicated. Each stream connection is unique and + cannot be copied. + + Raises + ------ + RuntimeError + Always raised, as copying is not supported for streaming trajectories + """ + raise RuntimeError( + "{} does not support copying".format(self.__class__.__name__) + ) + + def _reopen(self): + """Prepare stream for iteration - can only be called once. + + Streaming readers cannot be reopened once iteration begins. + This method is called internally during iteration setup and + will raise an error if called multiple times. + + Raises + ------ + RuntimeError + If the stream has already been opened for iteration + """ + if self._reopen_called: + raise RuntimeError( + "{}: Cannot reopen stream".format(self.__class__.__name__) + ) + self._frame = -1 + self._reopen_called = True + + def timeseries(self, **kwargs): + """Timeseries extraction is not supported for streaming trajectories. + + Streaming readers cannot randomly access frames or store bulk coordinate + data in memory, which ``timeseries()`` requires. Use sequential frame + iteration instead. + + Parameters + ---------- + **kwargs + Any keyword arguments (ignored, as method is not supported) + + Raises + ------ + RuntimeError + Always raised, as timeseries extraction is not supported for + streaming trajectories + """ + raise RuntimeError( + "{}: cannot access timeseries for streamed trajectories".format(self.__class__.__name__) + ) + + def __getitem__(self, frame): + """Return an iterator for slicing a streaming trajectory. + + Parameters + ---------- + frame : slice + Slice object. Only the step parameter is meaningful for streams. + + Returns + ------- + FrameIteratorAll or StreamFrameIteratorSliced + Iterator for the requested slice. + + Raises + ------ + TypeError + If frame is not a slice object. + ValueError + If slice contains start or stop values. + + Examples + -------- + >>> for ts in traj[:]: # All frames sequentially + ... process(ts) + >>> for ts in traj[::5]: # Every 5th frame + ... process(ts) + + See Also + -------- + StreamFrameIteratorSliced + """ + if isinstance(frame, slice): + _, _, step = self.check_slice_indices( + frame.start, frame.stop, frame.step + ) + if step is None: + return FrameIteratorAll(self) + else: + return StreamFrameIteratorSliced(self, step) + else: + raise TypeError( + "Streamed trajectories must be an indexed using a slice" + ) + + def check_slice_indices(self, start, stop, step): + """Check and validate slice indices for streaming trajectories. + + Streaming trajectories have fundamental constraints that differ from + traditional trajectory files: + + * **No start/stop indices**: Since streams process data continuously + without knowing the total length, ``start`` and ``stop`` must be ``None`` + * **Step-only slicing**: Only the ``step`` parameter is meaningful, + controlling how many frames to skip during iteration + * **Forward-only**: ``step`` must be positive (> 0) as streams cannot + be processed backward in time + + Parameters + ---------- + start : int or None + Starting frame index. Must be ``None`` for streaming readers. + stop : int or None + Ending frame index. Must be ``None`` for streaming readers. + step : int or None + Step size for iteration. Must be positive integer or ``None`` + (equivalent to 1). + + Returns + ------- + tuple + (start, stop, step) with validated values + + Raises + ------ + ValueError + If ``start`` or ``stop`` are not ``None``, or if ``step`` is + not a positive integer. + + Examples + -------- + Valid streaming slices:: + + traj[:] # All frames (step=None, equivalent to step=1) + traj[::2] # Every 2nd frame + traj[::10] # Every 10th frame + + Invalid streaming slices:: + + traj[5:] # Cannot specify start index + traj[:100] # Cannot specify stop index + traj[5:100:2] # Cannot specify start or stop indices + traj[::-1] # Cannot go backwards (negative step) + + See Also + -------- + __getitem__ + StreamFrameIteratorSliced + + + .. versionadded:: 2.10.0 + """ + if start is not None: + raise ValueError( + "{}: Cannot expect a start index from a stream, 'start' must be None".format( + self.__class__.__name__ + ) + ) + if stop is not None: + raise ValueError( + "{}: Cannot expect a stop index from a stream, 'stop' must be None".format( + self.__class__.__name__ + ) + ) + if step is not None: + if isinstance(step, numbers.Integral): + if step < 1: + raise ValueError( + "{}: Cannot go backwards in a stream, 'step' must be > 0".format( + self.__class__.__name__ + ) + ) + else: + raise ValueError( + "{}: 'step' must be an integer".format( + self.__class__.__name__ + ) + ) + + return start, stop, step + + def Writer(self, filename, **kwargs): + """Writer creation is not supported for streaming trajectories. + + Writer creation requires trajectory metadata that streaming readers + cannot provide due to their sequential processing nature. + + Parameters + ---------- + filename : str + Output filename (ignored, as method is not supported) + **kwargs + Additional keyword arguments (ignored, as method is not supported) + + Raises + ------ + RuntimeError + Always raised, as writer creation is not supported for streaming trajectories + """ + raise RuntimeError( + "{}: cannot create Writer for streamed trajectories".format( + self.__class__.__name__ + ) + ) + + def OtherWriter(self, filename, **kwargs): + """Writer creation is not supported for streaming trajectories. + + OtherWriter initialization requires frame-based parameters and trajectory + indexing information. Streaming readers process data sequentially + without meaningful frame indexing, making writer setup impossible. + + Parameters + ---------- + filename : str + Output filename (ignored, as method is not supported) + **kwargs + Additional keyword arguments (ignored, as method is not supported) + + Raises + ------ + RuntimeError + Always raised, as writer creation is not supported for streaming trajectories + """ + raise RuntimeError( + "{}: cannot create OtherWriter for streamed trajectories".format( + self.__class__.__name__ + ) + ) + + def __getstate__(self): + raise NotImplementedError( + "{} does not support pickling".format(self.__class__.__name__) + ) + + def __setstate__(self, state: object): + raise NotImplementedError( + "{} does not support pickling".format(self.__class__.__name__) + ) + + def __repr__(self): + return ( + "<{cls} {fname} with continuous stream of {natoms} atoms>" + "".format( + cls=self.__class__.__name__, + fname=self.filename, + natoms=self.n_atoms, + ) + ) + + +class StreamFrameIteratorSliced(FrameIteratorBase): + """Iterator for sliced frames in a streamed trajectory. + + Created when slicing a streaming trajectory with a step parameter + (e.g., ``trajectory[::n]``). Reads every nth frame from the continuous + stream, discarding intermediate frames for performance. + + This differs from iterating over all frames (``trajectory[:]``) which uses + :class:`FrameIteratorAll` and processes every frame sequentially without + skipping. + + Streaming constraints apply to the sliced iterator: + + - Frames cannot be accessed randomly (no indexing support) + - The total number of frames is unknown until streaming ends + - Rewinding or restarting iteration is not possible + - Only forward iteration with a fixed step size is supported + + Parameters + ---------- + trajectory : StreamReaderBase + The streaming trajectory reader to iterate over. Must be a + stream-based reader that supports continuous data reading. + step : int + Step size for iteration. Must be a positive integer. A step + of 1 reads every frame, step of 2 reads every other frame, etc. + + See Also + -------- + StreamReaderBase + FrameIteratorBase + + + .. versionadded:: 2.10.0 + """ + + def __init__(self, trajectory, step): + super().__init__(trajectory) + self._step = step + + def __iter__(self): + # Calling reopen tells reader + # it can't be reopened again + self.trajectory._reopen() + return self + + def __next__(self): + try: + # Burn the timesteps until we reach the desired step + # Don't use next() to avoid unnecessary transformations + while (self.trajectory._frame + 1) % self._step != 0: + self.trajectory._read_next_timestep() + except (EOFError, IOError): + # Don't rewind here like we normally would + raise StopIteration from None + + return self.trajectory.next() + + def __len__(self): + raise RuntimeError( + "{} has unknown length".format(self.__class__.__name__) + ) + + def __getitem__(self, frame): + raise RuntimeError("Sliced iterator does not support indexing") + + @property + def step(self): + """The step size for sliced frame iteration. + + Returns the step interval used when iterating through frames in a + streaming trajectory. For example, a step of 2 means every second + frame is processed, while a step of 1 processes every frame. + + Returns + ------- + int + Step size for iteration. Always a positive integer greater than 0. + + """ + return self._step \ No newline at end of file diff --git a/package/MDAnalysis/coordinates/timestep.pyx b/package/MDAnalysis/coordinates/timestep.pyx index ee12feae375..09e44c0f551 100644 --- a/package/MDAnalysis/coordinates/timestep.pyx +++ b/package/MDAnalysis/coordinates/timestep.pyx @@ -938,8 +938,8 @@ cdef class Timestep: return 1.0 @dt.setter - def dt(self, new): - self.data['dt'] = new + def dt(self, new_dt): + self.data['dt'] = new_dt @dt.deleter def dt(self): @@ -966,8 +966,8 @@ cdef class Timestep: return self.dt * self.frame + offset @time.setter - def time(self, new): - self.data['time'] = new + def time(self, new_time): + self.data['time'] = new_time @time.deleter def time(self): diff --git a/package/doc/sphinx/source/conf.py b/package/doc/sphinx/source/conf.py index 4c63b7bcacd..0aba418eb3b 100644 --- a/package/doc/sphinx/source/conf.py +++ b/package/doc/sphinx/source/conf.py @@ -349,4 +349,5 @@ class KeyStyle(UnsrtStyle): "pathsimanalysis": ("https://www.mdanalysis.org/PathSimAnalysis/", None), "mdahole2": ("https://www.mdanalysis.org/mdahole2/", None), "dask": ("https://docs.dask.org/en/stable/", None), + "imdclient": ("https://imdclient.readthedocs.io/en/stable/", None), } diff --git a/package/doc/sphinx/source/documentation_pages/coordinates/IMD.rst b/package/doc/sphinx/source/documentation_pages/coordinates/IMD.rst new file mode 100644 index 00000000000..d4d8013d61d --- /dev/null +++ b/package/doc/sphinx/source/documentation_pages/coordinates/IMD.rst @@ -0,0 +1 @@ +.. automodule:: MDAnalysis.coordinates.IMD \ No newline at end of file diff --git a/package/doc/sphinx/source/documentation_pages/coordinates_modules.rst b/package/doc/sphinx/source/documentation_pages/coordinates_modules.rst index a58a6b44baf..8a27767f23f 100644 --- a/package/doc/sphinx/source/documentation_pages/coordinates_modules.rst +++ b/package/doc/sphinx/source/documentation_pages/coordinates_modules.rst @@ -27,6 +27,7 @@ provide the format in the keyword argument *format* to coordinates/GSD coordinates/GRO coordinates/H5MD + coordinates/IMD coordinates/INPCRD coordinates/LAMMPS coordinates/MMTF diff --git a/package/doc/sphinx/source/documentation_pages/references.rst b/package/doc/sphinx/source/documentation_pages/references.rst index 7de33322fd3..284d5a63b78 100644 --- a/package/doc/sphinx/source/documentation_pages/references.rst +++ b/package/doc/sphinx/source/documentation_pages/references.rst @@ -229,6 +229,16 @@ If you use H5MD files using pp. 18 – 26, 2021. doi:`10.25080/majora-1b6fd038-005. `_ +.. comment:: + + If you use IMD capability with :mod:`MDAnalysis.coordinates.IMD.py`, please cite [IMDv3paper]_. + + .. [IMDv3paper] Authors (YEAR). + IMDv3 Manuscript Title. + *Journal*, 185. doi:`insert-doi-here `_ + +.. todo:: Fill in the final IMDv3 citation once the paper is published. + See https://github.com/MDAnalysis/mdanalysis/issues/5094 .. _citations-using-duecredit: diff --git a/package/pyproject.toml b/package/pyproject.toml index 008bdf8ecd8..bb9d9d48929 100644 --- a/package/pyproject.toml +++ b/package/pyproject.toml @@ -76,6 +76,7 @@ extra_formats = [ "pytng>=0.2.3", "gsd>3.0.0", "rdkit>=2022.09.1", + "imdclient>=0.2.2", ] analysis = [ "biopython>=1.80", diff --git a/testsuite/MDAnalysisTests/coordinates/test_imd.py b/testsuite/MDAnalysisTests/coordinates/test_imd.py new file mode 100644 index 00000000000..2fed2f761fd --- /dev/null +++ b/testsuite/MDAnalysisTests/coordinates/test_imd.py @@ -0,0 +1,589 @@ +"""Test for MDAnalysis trajectory reader expectations +""" + +import importlib +import pickle +import sys +from types import ModuleType +from weakref import ref + +import pytest +import numpy as np +from numpy.testing import ( + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_equal, +) + +import MDAnalysis as mda +from MDAnalysis.coordinates.IMD import ( + HAS_IMDCLIENT, + MIN_IMDCLIENT_VERSION, + IMDReader, +) +from MDAnalysis.transformations import translate + +if HAS_IMDCLIENT: + import imdclient + from imdclient.tests.server import InThreadIMDServer + from imdclient.tests.utils import ( + create_default_imdsinfo_v3, + get_free_port, + ) + +from MDAnalysisTests.coordinates.base import ( + assert_timestep_almost_equal, + BaseReference, + MultiframeReaderTest, +) +from MDAnalysisTests.datafiles import ( + COORDINATES_H5MD, + COORDINATES_TOPOLOGY, + COORDINATES_TRR, +) + + +class IMDModuleStateManager: + """Context manager to completely backup and restore imdclient/IMD module state. + + We need a custom manager because IMD changes its own state (HAS_IMDCLIENT) when it is imported + and we are going to manipulate the state of the imdclient module that IMD sees. + """ + + def __init__(self): + self.original_modules = None + self.imd_was_imported = False + + def __enter__(self): + # Backup sys.modules + self.original_modules = sys.modules.copy() + + # Check if IMD module was already imported + self.imd_was_imported = "MDAnalysis.coordinates.IMD" in sys.modules + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Restore sys.modules completely first + sys.modules.clear() + sys.modules.update(self.original_modules) + + # If IMD module was originally imported, force a fresh reload to restore original state + # This ensures that HAS_IMDCLIENT and other globals are recalculated with the real imdclient + if self.imd_was_imported: + # Remove the potentially corrupted IMD module + sys.modules.pop("MDAnalysis.coordinates.IMD", None) + # Fresh import will re-evaluate all globals + import MDAnalysis.coordinates.IMD + + +class TestImport: + """Test imdclient import behavior and HAS_IMDCLIENT flag.""" + + def _setup_mock_imdclient(self, monkeypatch, version): + """Helper method to set up mock imdclient with specified version.""" + # Remove IMD and imdclient modules to force fresh import + monkeypatch.delitem( + sys.modules, "MDAnalysis.coordinates.IMD", raising=False + ) + monkeypatch.delitem(sys.modules, "imdclient", raising=False) + + module_name = "imdclient" + mocked_module = ModuleType(module_name) + IMDClient_module = ModuleType(f"{module_name}.IMDClient") + + class MockIMDClient: + pass + + IMDClient_module.IMDClient = MockIMDClient + mocked_module.IMDClient = IMDClient_module + mocked_module.__version__ = version + + utils_module = ModuleType(f"{module_name}.utils") + utils_module.parse_host_port = lambda x: ("localhost", 12345) + mocked_module.utils = utils_module + + monkeypatch.setitem(sys.modules, module_name, mocked_module) + monkeypatch.setitem( + sys.modules, f"{module_name}.IMDClient", IMDClient_module + ) + monkeypatch.setitem(sys.modules, f"{module_name}.utils", utils_module) + + return mocked_module + + def test_has_minversion(self, monkeypatch): + """Test that HAS_IMDCLIENT is True when imdclient >= MIN_IMDCLIENT_VERSION.""" + with IMDModuleStateManager(): + self._setup_mock_imdclient(monkeypatch, str(MIN_IMDCLIENT_VERSION)) + + # Import and check HAS_IMDCLIENT with compatible version + import MDAnalysis.coordinates.IMD + from MDAnalysis.coordinates.IMD import HAS_IMDCLIENT + + assert ( + HAS_IMDCLIENT + ), f"HAS_IMDCLIENT should be True with version {MIN_IMDCLIENT_VERSION}" + + def test_no_minversion(self, monkeypatch): + """Test that HAS_IMDCLIENT is False when imdclient version is too old.""" + with IMDModuleStateManager(): + self._setup_mock_imdclient(monkeypatch, "0.0.0") + + # Import and check HAS_IMDCLIENT with incompatible version + import MDAnalysis.coordinates.IMD + from MDAnalysis.coordinates.IMD import HAS_IMDCLIENT + + assert ( + not HAS_IMDCLIENT + ), "HAS_IMDCLIENT should be False with version 0.0.0" + + def test_missing_ImportError(self, monkeypatch): + """Test that IMDReader raises ImportError when HAS_IMDCLIENT=False.""" + with IMDModuleStateManager(): + self._setup_mock_imdclient(monkeypatch, "0.0.0") + + # Import with incompatible version (HAS_IMDCLIENT=False) + import MDAnalysis.coordinates.IMD + from MDAnalysis.coordinates.IMD import IMDReader + + # IMDReader should raise ImportError when HAS_IMDCLIENT=False + with pytest.raises( + ImportError, match="IMDReader requires the imdclient" + ): + IMDReader("imd://localhost:12345", n_atoms=5) + + +@pytest.mark.skipif(not HAS_IMDCLIENT, reason="IMDClient not installed") +class IMDReference(BaseReference): + def __init__(self): + super(IMDReference, self).__init__() + # Serve TRR traj data via the server + traj = mda.coordinates.TRR.TRRReader(COORDINATES_TRR) + self.server = InThreadIMDServer(traj) + self.server.set_imdsessioninfo(create_default_imdsinfo_v3()) + + self.n_atoms = traj.n_atoms + self.prec = 3 + + self.trajectory = "imd://localhost" + self.topology = COORDINATES_TOPOLOGY + self.changing_dimensions = True + self.reader = IMDReader + + self.first_frame.velocities = self.first_frame.positions / 10 + self.first_frame.forces = self.first_frame.positions / 100 + + self.second_frame.velocities = self.second_frame.positions / 10 + self.second_frame.forces = self.second_frame.positions / 100 + + self.last_frame.velocities = self.last_frame.positions / 10 + self.last_frame.forces = self.last_frame.positions / 100 + + self.jump_to_frame.velocities = self.jump_to_frame.positions / 10 + self.jump_to_frame.forces = self.jump_to_frame.positions / 100 + + def iter_ts(self, i): + ts = self.first_frame.copy() + ts.positions = 2**i * self.first_frame.positions + ts.velocities = ts.positions / 10 + ts.forces = ts.positions / 100 + ts.time = i + ts.frame = i + return ts + + +@pytest.mark.skipif(not HAS_IMDCLIENT, reason="IMDClient not installed") +class TestIMDReaderBaseAPI(MultiframeReaderTest): + + @pytest.fixture(scope="function") + def ref(self): + """Not a static method like in base class- need new server for each test""" + reference = IMDReference() + yield reference + reference.server.cleanup() + + @staticmethod + @pytest.fixture() + def reader(ref): + # This will start the test IMD Server, waiting for a connection + # to then send handshake & first frame + ref.server.handshake_sequence("localhost") + # This will connect to the test IMD Server and read the first frame + reader = ref.reader( + f"{ref.trajectory}:{ref.server.port}", + n_atoms=ref.n_atoms, + buffer_size=1 * 1024 * 1024, + ) + # Send the rest of the frames- small enough to all fit in socket itself + ref.server.send_frames(1, 5) + + reader.add_auxiliary( + "lowf", + ref.aux_lowf, + dt=ref.aux_lowf_dt, + initial_time=0, + time_selector=None, + ) + reader.add_auxiliary( + "highf", + ref.aux_highf, + dt=ref.aux_highf_dt, + initial_time=0, + time_selector=None, + ) + yield reader + reader.close() + + @staticmethod + @pytest.fixture() + def transformed(ref): + # This will start the test IMD Server, waiting for a connection + # to then send handshake & first frame + ref.server.handshake_sequence("localhost") + # This will connect to the test IMD Server and read the first frame + transformed = ref.reader( + f"{ref.trajectory}:{ref.server.port}", + n_atoms=ref.n_atoms, + buffer_size=1 * 1024 * 1024, + ) + # Send the rest of the frames- small enough to all fit in socket itself + ref.server.send_frames(1, 5) + transformed.add_transformations( + translate([1, 1, 1]), translate([0, 0, 0.33]) + ) + return transformed + + def test_n_frames(self, ref, reader): + with pytest.raises(RuntimeError, match="n_frames is unknown"): + reader.n_frames + + def test_first_frame(self, ref, reader): + # don't rewind here as in inherited base test + assert_timestep_almost_equal( + reader.ts, ref.first_frame, decimal=ref.prec + ) + + def test_get_writer_1(self, ref, reader, tmpdir): + with pytest.raises( + RuntimeError, + match="cannot create Writer for streamed trajectories", + ): + reader.Writer(str(tmpdir.join("output"))) + + def test_get_writer_2(self, ref, reader, tmpdir): + with pytest.raises( + RuntimeError, + match="cannot create Writer for streamed trajectories", + ): + reader.Writer(str(tmpdir.join("output")), n_atoms=100) + + def test_OtherWriter_RuntimeError(self, reader, tmpdir): + with pytest.raises( + RuntimeError, + match="cannot create OtherWriter for streamed trajectories", + ): + reader.OtherWriter(tmpdir.join("output")) + + def test_total_time(self, ref, reader): + pytest.skip("`total_time` is unknown for IMDReader") + + def test_changing_dimensions(self, ref, reader): + with pytest.raises( + RuntimeError, match="Stream-based readers can't be rewound" + ): + reader.rewind() + + def test_iter(self, ref, reader): + for i, ts in enumerate(reader): + assert_timestep_almost_equal(ts, ref.iter_ts(i), decimal=ref.prec) + + def test_first_dimensions(self, ref, reader): + # don't rewind here as in inherited base test + if ref.dimensions is None: + assert reader.ts.dimensions is None + else: + assert_allclose( + reader.ts.dimensions, + ref.dimensions, + rtol=0, + atol=1.5 * 10 ** (-ref.prec), + ) + + def test_transformed(self, ref, transformed): + # see transformed fixture + ref_trans = ref.first_frame.positions + 1 + ref_trans[:, 2] += 0.33 + assert_allclose(transformed.ts.positions, ref_trans) + + def test_volume(self, ref, reader): + # don't rewind here as in inherited base test + vol = reader.ts.volume + # Here we can only be sure about the numbers upto the decimal point due + # to limited floating point precision. + assert_allclose(vol, ref.volume, rtol=0, atol=1.5e0) + + def test_reload_auxiliaries_from_description(self, ref, reader): + pytest.skip("Cannot create two IMDReaders on the same stream") + + def test_stop_iter(self, reader): + with pytest.raises( + RuntimeError, match="Stream-based readers can't be rewound" + ): + reader.rewind() + + def test_iter_rewinds(self, reader): + pytest.skip("IMDReader cannot be rewound") + + def test_timeseries_shape(self, reader): + pytest.skip("IMDReader does not support timeseries") + + def test_timeseries_asel_shape(self, reader): + pytest.skip("IMDReader does not support timeseries") + + def test_timeseries_values(self, reader): + pytest.skip("IMDReader does not support timeseries") + + def test_transformations_2iter(self, ref, transformed): + pytest.skip("IMDReader cannot be reopened") + + def test_transformations_slice(self, ref, transformed): + pytest.skip("IMDReader cannot be reopened") + + def test_transformations_switch_frame(self, ref, transformed): + pytest.skip("IMDReader cannot be reopened") + + def test_transformation_rewind(self, ref, transformed): + pytest.skip("IMDReader cannot be reopened") + + def test_pickle_reader(self, reader): + with pytest.raises( + NotImplementedError, match="does not support pickling" + ): + pickle.dumps(reader) + + def test_pickle_next_ts_reader(self, reader): + pytest.skip("IMDReader cannot be pickled") + + def test_pickle_last_ts_reader(self, reader): + pytest.skip("IMDReader cannot be pickled") + + def test_transformations_copy(self, ref, transformed): + with pytest.raises(RuntimeError, match="does not support copying"): + transformed.copy() + + def test_timeseries_empty_asel(self, reader): + pytest.skip("IMDReader does not support timeseries") + + def test_timeseries_empty_atomgroup(self, reader): + pytest.skip("IMDReader does not support timeseries") + + def test_timeseries_asel_warns_deprecation(self, reader): + pytest.skip("IMDReader does not support timeseries") + + def test_timeseries_atomgroup(self, reader): + pytest.skip("IMDReader does not support timeseries") + + def test_timeseries_atomgroup_asel_mutex(self, reader): + pytest.skip("IMDReader does not support timeseries") + + def test_last_frame(self, ref, reader): + pytest.skip("IMDReader cannot be rewound") + + def test_go_over_last_frame(self, ref, reader): + pytest.skip("IMDReader must be an indexed using a slice") + + def test_frame_jump(self, ref, reader): + pytest.skip("IMDReader must be an indexed using a slice") + + def test_frame_jump_issue1942(self, ref, reader): + pytest.skip("IMDReader must be an indexed using a slice") + + def test_next_gives_second_frame(self, ref, reader): + # don't recreate reader here as in inherited base test + ts = reader.next() + assert_timestep_almost_equal(ts, ref.second_frame, decimal=ref.prec) + + def test_frame_collect_all_same(self, reader): + pytest.skip("IMDReader has independent coordinates") + + +@pytest.fixture +def universe(): + return mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD) + + +@pytest.fixture +def port(): + return get_free_port() + + +@pytest.fixture +def imdsinfo(): + return create_default_imdsinfo_v3() + + +@pytest.mark.skipif(not HAS_IMDCLIENT, reason="IMDClient not installed") +class TestStreamIteration: + @pytest.fixture + def reader(self, universe, imdsinfo): + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(imdsinfo) + server.handshake_sequence("localhost", first_frame=True) + reader = IMDReader( + f"imd://localhost:{server.port}", + n_atoms=universe.trajectory.n_atoms, + buffer_size=1 * 1024 * 1024, + ) + server.send_frames(1, 5) + + yield reader + server.cleanup() + reader.close() + + def test_iterate_step(self, reader, universe): + i = 0 + for ts in reader[::2]: + assert ts.frame == i + i += 2 + + def test_iterate_twice_sliced_raises_error(self, reader): + for ts in reader[::2]: + pass + with pytest.raises(RuntimeError, match="Cannot reopen stream"): + for ts in reader[::2]: + pass + + def test_iterate_twice_all_raises_error(self, reader): + for ts in reader: + pass + with pytest.raises(RuntimeError, match="Cannot reopen stream"): + for ts in reader: + pass + + def test_iterate_twice_fi_all_raises_error(self, reader): + for ts in reader[:]: + pass + with pytest.raises(RuntimeError, match="Cannot reopen stream"): + for ts in reader[:]: + pass + + def test_index_stream_raises_error(self, reader): + with pytest.raises(TypeError, match="Streamed trajectories must be"): + reader[0] + + def test_iterate_backwards_raises_error(self, reader): + with pytest.raises(ValueError, match="Cannot go backwards"): + for ts in reader[::-1]: + pass + + def test_iterate_start_stop_raises_error(self, reader): + with pytest.raises(ValueError, match="Cannot expect a start index"): + for ts in reader[1:3]: + pass + + def test_subslice_fi_all_after_iteration_raises_error(self, reader): + sliced_reader = reader[:] + for ts in sliced_reader: + pass + sub_sliced_reader = sliced_reader[::1] + with pytest.raises(RuntimeError): + for ts in sub_sliced_reader: + pass + + def test_timeseries_raises(self, reader): + with pytest.raises( + RuntimeError, + match="cannot access timeseries for streamed trajectories", + ): + reader.timeseries() + + def test_step_property(self, reader): + """Test that the step property returns the correct step size.""" + # Test step property for different slice steps + sliced_reader = reader[::1] + assert sliced_reader.step == 1 + + sliced_reader_step5 = reader[::5] + assert sliced_reader_step5.step == 5 + + +@pytest.mark.skipif(not HAS_IMDCLIENT, reason="IMDClient not installed") +def test_n_atoms_not_specified(universe, imdsinfo): + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(imdsinfo) + server.handshake_sequence("localhost", first_frame=True) + with pytest.raises( + ValueError, + match="IMDReader: n_atoms must be specified", + ): + IMDReader( + f"imd://localhost:{server.port}", + ) + server.cleanup() + + +@pytest.mark.skipif(not HAS_IMDCLIENT, reason="IMDClient not installed") +def test_imd_stream_empty(universe, imdsinfo): + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(imdsinfo) + server.handshake_sequence("localhost", first_frame=False) + with pytest.raises( + RuntimeError, + match="IMDReader: Read error", + ): + IMDReader( + f"imd://localhost:{server.port}", + n_atoms=universe.trajectory.n_atoms, + ) + server.cleanup() + + +@pytest.mark.skipif(not HAS_IMDCLIENT, reason="IMDClient not installed") +def test_create_imd_universe(universe, imdsinfo): + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(imdsinfo) + server.handshake_sequence("localhost", first_frame=True) + u_imd = mda.Universe( + COORDINATES_TOPOLOGY, + f"imd://localhost:{server.port}", + n_atoms=universe.trajectory.n_atoms, + ) + assert type(u_imd.trajectory).__name__ == "IMDReader" + with pytest.raises(ValueError, match="IMDReader: Invalid IMD URL"): + u_imd = mda.Universe( + COORDINATES_TOPOLOGY, + f"imd://localhost:{port}/invalid", + n_atoms=universe.trajectory.n_atoms, + ) + server.cleanup() + + +def test_imd_format_hint(): + assert IMDReader._format_hint("imd://localhost:12345") + assert IMDReader._format_hint("imd://localhost:12345/invalid") + assert not IMDReader._format_hint("not_a_valid_imd_url") + assert not IMDReader._format_hint(12345) + assert not IMDReader._format_hint(None) + + +@pytest.mark.skipif(not HAS_IMDCLIENT, reason="IMDClient not installed") +def test_wrong_imd_protocol_version(universe, imdsinfo): + """Test that IMDReader raises ValueError for non-v3 protocol versions.""" + # Modify the fixture to have wrong version + imdsinfo.version = 2 # Wrong version, should be 3 + + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(imdsinfo) + server.handshake_sequence("localhost", first_frame=True) + + with pytest.raises( + ValueError, + match=rf"IMDReader: Detected IMD version v{imdsinfo.version}, " + rf"but IMDReader is only compatible with v3", + ): + IMDReader( + f"imd://localhost:{server.port}", + n_atoms=universe.trajectory.n_atoms, + ) + server.cleanup() diff --git a/testsuite/MDAnalysisTests/coordinates/test_reader_api.py b/testsuite/MDAnalysisTests/coordinates/test_reader_api.py index 4ae5c0f5c6c..2e6e29c852a 100644 --- a/testsuite/MDAnalysisTests/coordinates/test_reader_api.py +++ b/testsuite/MDAnalysisTests/coordinates/test_reader_api.py @@ -27,6 +27,7 @@ from MDAnalysis.coordinates.base import ( ReaderBase, SingleFrameReaderBase, + StreamReaderBase, Timestep, ) from numpy.testing import assert_allclose, assert_equal @@ -81,6 +82,24 @@ def _read_first_frame(self): self.ts.frame = 0 +class AmazingStreamReader(StreamReaderBase): + format = "AmazingStream" + + def __init__(self, filename, n_atoms): + self.n_atoms = n_atoms + self._mocked_frames = [Timestep(n_atoms) for _ in range(3)] + super().__init__(filename) + + def _read_frame(self, frame): + self._frame = frame + if self._frame >= len(self._mocked_frames): + raise EOFError("End of stream") + ts = self._mocked_frames[self._frame] + ts.frame = self._frame + self.ts = ts + return ts + + class _TestReader(object): __test__ = False """Basic API readers""" @@ -445,3 +464,80 @@ def test_iter_rewind(self, reader): assert_allclose(ts.positions, np.zeros((10, 3))) assert_allclose(reader.ts.positions, np.zeros((10, 3))) + + +class _Stream: + n_atoms = 3 + readerclass = AmazingStreamReader + + +class TestStreamReader(_Stream): + @pytest.fixture + def reader(self): + return self.readerclass("dummy", n_atoms=self.n_atoms) + + def test_repr(self, reader): + rep = repr(reader) + assert "AmazingStreamReader" in rep + assert "continuous stream" in rep + assert "3 atoms" in rep + + def test_read_and_exhaust_stream(self, reader): + ts0 = reader.next() + ts1 = reader.next() + ts2 = reader.next() + assert ts0.frame == 0 + assert ts1.frame == 1 + assert ts2.frame == 2 + + with pytest.raises(StopIteration): + reader.next() + + def test_len_and_n_frames_raise(self, reader): + with pytest.raises(RuntimeError): + _ = len(reader) + with pytest.raises(RuntimeError): + _ = reader.n_frames + + def test_rewind_raises(self, reader): + with pytest.raises(RuntimeError, match="can't be rewound"): + reader.rewind() + + def test_copy_raises(self, reader): + with pytest.raises(RuntimeError, match="does not support copying"): + reader.copy() + + def test_timeseries_raises(self, reader): + with pytest.raises(RuntimeError, match="cannot access timeseries"): + reader.timeseries() + + def test_reopen_only_once(self, reader): + reader._reopen() + with pytest.raises(RuntimeError, match="Cannot reopen stream"): + reader._reopen() + + def test_slice_reader(self, reader): + sliced = reader[slice(None, None, 2)] + with pytest.raises(RuntimeError, match="has unknown length"): + len(sliced) + with pytest.raises(RuntimeError, match="does not support indexing"): + sliced[0] + + for i, ts in enumerate(sliced): + assert ts.frame == i * 2 + + def test_check_slice_index_errors(self, reader): + with pytest.raises(ValueError, match="start.*must be None"): + reader.check_slice_indices(0, None, 1) + with pytest.raises(ValueError, match="stop.*must be None"): + reader.check_slice_indices(None, 1, 1) + with pytest.raises(ValueError, match="must be > 0"): + reader.check_slice_indices(None, None, 0) + with pytest.raises(ValueError, match="must be an integer"): + reader.check_slice_indices(None, None, 1.5) + + def test_pickle_methods(self, reader): + with pytest.raises(NotImplementedError): + reader.__getstate__() + with pytest.raises(NotImplementedError): + reader.__setstate__({})