diff --git a/imdclient/IMDClient.py b/imdclient/IMDClient.py index 234b5de..1bc56ac 100644 --- a/imdclient/IMDClient.py +++ b/imdclient/IMDClient.py @@ -247,9 +247,7 @@ def _await_IMD_handshake(self) -> IMDSessionInfo: header = IMDHeader(h_buf) if header.type != IMDHeaderType.IMD_HANDSHAKE: - raise ValueError( - f"Expected header type `IMD_HANDSHAKE`, got {header.type}" - ) + raise ValueError(f"Expected header type `IMD_HANDSHAKE`, got {header.type}") if header.length not in IMDVERSIONS: # Try swapping endianness @@ -289,9 +287,7 @@ def _await_IMD_handshake(self) -> IMDSessionInfo: f"Expected header type `IMD_SESSIONINFO`, got {header.type}" ) if header.length != 7: - raise ValueError( - f"Expected header length 7, got {header.length}" - ) + raise ValueError(f"Expected header length 7, got {header.length}") data = bytearray(7) read_into_buf(self._conn, data) sinfo = parse_imdv3_session_info(data, end) @@ -310,9 +306,7 @@ def _go(self): if self._continue_after_disconnect is not None: wait_behavior = (int)(not self._continue_after_disconnect) - wait_packet = create_header_bytes( - IMDHeaderType.IMD_WAIT, wait_behavior - ) + wait_packet = create_header_bytes(IMDHeaderType.IMD_WAIT, wait_behavior) self._conn.sendall(wait_packet) logger.debug( "IMDClient: Attempted to change wait behavior to %s", @@ -515,12 +509,14 @@ def _read(self, buf): """Wraps `read_into_buf` call to give uniform error handling which indicates end of stream""" try: read_into_buf(self._conn, buf) - except (ConnectionError, TimeoutError, BlockingIOError, Exception): + except (ConnectionError, TimeoutError, BlockingIOError, Exception) as e: # ConnectionError: Server is definitely done sending frames, socket is closed # TimeoutError: Server is *likely* done sending frames. # BlockingIOError: Occurs when timeout is 0 in place of a TimeoutError. Server is *likely* done sending frames # OSError: Occurs when main thread disconnects from the server and closes the socket, but producer thread attempts to read another frame # Exception: Something unexpected happened + if e.isinstance(BlockingIOError): + logger.debug("IMDProducer: BlockingIOError occurred, Amru called it") raise EOFError @@ -560,9 +556,7 @@ def _parse_imdframe(self): ) self._prev_energies = self._imdf.energies - self._expect_header( - IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms - ) + self._expect_header(IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms) self._read(self._positions) np.copyto( self._imdf.positions, @@ -571,8 +565,7 @@ def _parse_imdframe(self): ).reshape((self._n_atoms, 3)), ) elif ( - header.type == IMDHeaderType.IMD_FCOORDS - and header.length == self._n_atoms + header.type == IMDHeaderType.IMD_FCOORDS and header.length == self._n_atoms ): # If we received positions but no energies # use the last energies received @@ -592,9 +585,7 @@ def _parse_imdframe(self): def _pause(self): self._conn.settimeout(0) - logger.debug( - "IMDProducer: Pausing simulation because buffer is almost full" - ) + logger.debug("IMDProducer: Pausing simulation because buffer is almost full") pause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0) try: self._conn.sendall(pause) @@ -659,9 +650,7 @@ def __init__( def _pause(self): self._conn.settimeout(0) - logger.debug( - "IMDProducer: Pausing simulation because buffer is almost full" - ) + logger.debug("IMDProducer: Pausing simulation because buffer is almost full") pause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0) try: self._conn.sendall(pause) @@ -710,13 +699,9 @@ def _parse_imdframe(self): if self._imdsinfo.box: self._expect_header(IMDHeaderType.IMD_BOX, expected_value=1) self._read(self._box) - self._imdf.box = parse_box_bytes( - self._box, self._imdsinfo.endianness - ) + self._imdf.box = parse_box_bytes(self._box, self._imdsinfo.endianness) if self._imdsinfo.positions: - self._expect_header( - IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms - ) + self._expect_header(IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms) self._read(self._positions) np.copyto( self._imdf.positions, @@ -736,9 +721,7 @@ def _parse_imdframe(self): ).reshape((self._n_atoms, 3)), ) if self._imdsinfo.forces: - self._expect_header( - IMDHeaderType.IMD_FORCES, expected_value=self._n_atoms - ) + self._expect_header(IMDHeaderType.IMD_FORCES, expected_value=self._n_atoms) self._read(self._forces) np.copyto( self._imdf.forces, @@ -802,13 +785,9 @@ def __init__( # even if they aren't sent every frame. Can be optimized if needed imdf_memsize = imdframe_memsize(n_atoms, imdsinfo) self._total_imdf = buffer_size // imdf_memsize - logger.debug( - f"IMDFrameBuffer: Total IMDFrames allocated: {self._total_imdf}" - ) + logger.debug(f"IMDFrameBuffer: Total IMDFrames allocated: {self._total_imdf}") if self._total_imdf == 0: - raise ValueError( - "Buffer size is too small to hold a single IMDFrame" - ) + raise ValueError("Buffer size is too small to hold a single IMDFrame") for i in range(self._total_imdf): self._empty_q.put(IMDFrame(n_atoms, imdsinfo)) @@ -820,10 +799,7 @@ def __init__( def is_full(self): logger.debug("IMDFrameBuffer: Checking if full") - if ( - self._empty_q.qsize() / self._total_imdf - <= self._pause_empty_proportion - ): + if self._empty_q.qsize() / self._total_imdf <= self._pause_empty_proportion: return True logger.debug( @@ -836,8 +812,7 @@ def wait_for_space(self): # Before acquiring the lock, check if we can return immediately if ( - self._empty_q.qsize() / self._total_imdf - >= self._unpause_empty_proportion + self._empty_q.qsize() / self._total_imdf >= self._unpause_empty_proportion ) and not self._consumer_finished: return try: @@ -852,9 +827,7 @@ def wait_for_space(self): ) self._empty_imdf_avail.wait() - logger.debug( - "IMDProducer: Got space in buffer or consumer finished" - ) + logger.debug("IMDProducer: Got space in buffer or consumer finished") if self._consumer_finished: logger.debug("IMDProducer: Noticing consumer finished") diff --git a/imdclient/tests/server.py b/imdclient/tests/server.py index 6c467c3..8832999 100644 --- a/imdclient/tests/server.py +++ b/imdclient/tests/server.py @@ -37,7 +37,7 @@ def set_imdsessioninfo(self, imdsinfo): @property def port(self): """Get the port the server is bound to. - + Returns: int: The port number, or None if not bound yet. """ @@ -121,15 +121,7 @@ def send_frame(self, i): endianness = self.imdsinfo.endianness if self.imdsinfo.time: - time_header = create_header_bytes(IMDHeaderType.IMD_TIME, 1) - time = struct.pack( - f"{endianness}ddQ", - self.traj[i].dt, - self.traj[i].time, - self.traj[i].frame, - ) - - self.conn.sendall(time_header + time) + self.send_time_packet(i) if self.imdsinfo.energies: energy_header = create_header_bytes(IMDHeaderType.IMD_ENERGIES, 1) @@ -187,14 +179,23 @@ def send_frame(self, i): self.conn.sendall(force_header + force) + def send_time_packet(self, i): + time_header = create_header_bytes(IMDHeaderType.IMD_TIME, 1) + time = struct.pack( + f"{self.imdsinfo.endianness}ddQ", + self.traj[i].dt, + self.traj[i].time, + self.traj[i].frame, + ) + + self.conn.sendall(time_header + time) + def expect_packet(self, packet_type, expected_length=None): head_buf = bytearray(IMDHEADERSIZE) read_into_buf(self.conn, head_buf) header = IMDHeader(head_buf) if header.type != packet_type: - raise ValueError( - f"Expected {packet_type} packet, got {header.type}" - ) + raise ValueError(f"Expected {packet_type} packet, got {header.type}") if expected_length is not None and header.length != expected_length: raise ValueError( f"Expected packet length {expected_length}, got {header.length}" diff --git a/imdclient/tests/test_imdclient.py b/imdclient/tests/test_imdclient.py index 656d08c..1431495 100644 --- a/imdclient/tests/test_imdclient.py +++ b/imdclient/tests/test_imdclient.py @@ -22,9 +22,7 @@ logger = logging.getLogger("imdclient.IMDClient") file_handler = logging.FileHandler("test.log") -formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") file_handler.setFormatter(formatter) logger.addHandler(file_handler) logger.setLevel(logging.DEBUG) @@ -62,8 +60,7 @@ def server_client_two_frame_buf(self, universe, imdsinfo): f"localhost", server.port, universe.trajectory.n_atoms, - buffer_size=imdframe_memsize(universe.trajectory.n_atoms, imdsinfo) - * 2, + buffer_size=imdframe_memsize(universe.trajectory.n_atoms, imdsinfo) * 2, ) server.join_accept_thread() yield server, client @@ -112,9 +109,7 @@ def test_traj_unchanged(self, server_client, universe): assert_allclose(universe.trajectory[i].positions, imdf.positions) assert_allclose(universe.trajectory[i].velocities, imdf.velocities) assert_allclose(universe.trajectory[i].forces, imdf.forces) - assert_allclose( - universe.trajectory[i].triclinic_dimensions, imdf.box - ) + assert_allclose(universe.trajectory[i].triclinic_dimensions, imdf.box) def test_pause_resume_continue(self, server_client_two_frame_buf): server, client = server_client_two_frame_buf @@ -174,23 +169,38 @@ def test_continue_after_disconnect(self, universe, imdsinfo, cont): continue_after_disconnect=cont, ) server.join_accept_thread() - server.expect_packet( - IMDHeaderType.IMD_WAIT, expected_length=(int)(not cont) - ) + server.expect_packet(IMDHeaderType.IMD_WAIT, expected_length=(int)(not cont)) def test_incorrect_atom_count(self, server_client_incorrect_atoms, universe): server, client = server_client_incorrect_atoms - + server.send_frame(0) - + with pytest.raises(EOFError) as exc_info: client.get_imdframe() - + error_msg = str(exc_info.value) assert f"Expected n_atoms value {universe.atoms.n_atoms + 1}" in error_msg assert f"got {universe.atoms.n_atoms}" in error_msg assert "Ensure you are using the correct topology file" in error_msg + def test_partial_frame_in_paused_state(self, server_client_two_frame_buf): + import time + + server, client = server_client_two_frame_buf + server.send_frames(0, 2) + # Client's buffer is filled. client should send pause + server.expect_packet(IMDHeaderType.IMD_PAUSE) + + server.send_time_packet(2) + + time.sleep(6) + + client.get_imdframe() + client.get_imdframe() + # i expect this will instantly fail out + client.get_imdframe() + class TestIMDClientV3ContextManager: @pytest.fixture @@ -224,16 +234,10 @@ def test_context_manager_traj_unchanged(self, server, universe): assert_allclose(universe.trajectory[i].time, imdf.time) assert_allclose(universe.trajectory[i].dt, imdf.dt) assert_allclose(universe.trajectory[i].data["step"], imdf.step) - assert_allclose( - universe.trajectory[i].positions, imdf.positions - ) - assert_allclose( - universe.trajectory[i].velocities, imdf.velocities - ) + assert_allclose(universe.trajectory[i].positions, imdf.positions) + assert_allclose(universe.trajectory[i].velocities, imdf.velocities) assert_allclose(universe.trajectory[i].forces, imdf.forces) - assert_allclose( - universe.trajectory[i].triclinic_dimensions, imdf.box - ) + assert_allclose(universe.trajectory[i].triclinic_dimensions, imdf.box) i += 1 server.expect_packet(IMDHeaderType.IMD_DISCONNECT) assert i == 5