Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 18 additions & 45 deletions imdclient/IMDClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,7 @@ def _await_IMD_handshake(self) -> IMDSessionInfo:
header = IMDHeader(h_buf)

if header.type != IMDHeaderType.IMD_HANDSHAKE:
raise ValueError(
f"Expected header type `IMD_HANDSHAKE`, got {header.type}"
)
raise ValueError(f"Expected header type `IMD_HANDSHAKE`, got {header.type}")

if header.length not in IMDVERSIONS:
# Try swapping endianness
Expand Down Expand Up @@ -289,9 +287,7 @@ def _await_IMD_handshake(self) -> IMDSessionInfo:
f"Expected header type `IMD_SESSIONINFO`, got {header.type}"
)
if header.length != 7:
raise ValueError(
f"Expected header length 7, got {header.length}"
)
raise ValueError(f"Expected header length 7, got {header.length}")
data = bytearray(7)
read_into_buf(self._conn, data)
sinfo = parse_imdv3_session_info(data, end)
Expand All @@ -310,9 +306,7 @@ def _go(self):

if self._continue_after_disconnect is not None:
wait_behavior = (int)(not self._continue_after_disconnect)
wait_packet = create_header_bytes(
IMDHeaderType.IMD_WAIT, wait_behavior
)
wait_packet = create_header_bytes(IMDHeaderType.IMD_WAIT, wait_behavior)
self._conn.sendall(wait_packet)
logger.debug(
"IMDClient: Attempted to change wait behavior to %s",
Expand Down Expand Up @@ -515,12 +509,14 @@ def _read(self, buf):
"""Wraps `read_into_buf` call to give uniform error handling which indicates end of stream"""
try:
read_into_buf(self._conn, buf)
except (ConnectionError, TimeoutError, BlockingIOError, Exception):
except (ConnectionError, TimeoutError, BlockingIOError, Exception) as e:
# ConnectionError: Server is definitely done sending frames, socket is closed
# TimeoutError: Server is *likely* done sending frames.
# BlockingIOError: Occurs when timeout is 0 in place of a TimeoutError. Server is *likely* done sending frames
# OSError: Occurs when main thread disconnects from the server and closes the socket, but producer thread attempts to read another frame
# Exception: Something unexpected happened
if e.isinstance(BlockingIOError):
logger.debug("IMDProducer: BlockingIOError occurred, Amru called it")
raise EOFError


Expand Down Expand Up @@ -560,9 +556,7 @@ def _parse_imdframe(self):
)
self._prev_energies = self._imdf.energies

self._expect_header(
IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms
)
self._expect_header(IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms)
self._read(self._positions)
np.copyto(
self._imdf.positions,
Expand All @@ -571,8 +565,7 @@ def _parse_imdframe(self):
).reshape((self._n_atoms, 3)),
)
elif (
header.type == IMDHeaderType.IMD_FCOORDS
and header.length == self._n_atoms
header.type == IMDHeaderType.IMD_FCOORDS and header.length == self._n_atoms
):
# If we received positions but no energies
# use the last energies received
Expand All @@ -592,9 +585,7 @@ def _parse_imdframe(self):

def _pause(self):
self._conn.settimeout(0)
logger.debug(
"IMDProducer: Pausing simulation because buffer is almost full"
)
logger.debug("IMDProducer: Pausing simulation because buffer is almost full")
pause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0)
try:
self._conn.sendall(pause)
Expand Down Expand Up @@ -659,9 +650,7 @@ def __init__(

def _pause(self):
self._conn.settimeout(0)
logger.debug(
"IMDProducer: Pausing simulation because buffer is almost full"
)
logger.debug("IMDProducer: Pausing simulation because buffer is almost full")
pause = create_header_bytes(IMDHeaderType.IMD_PAUSE, 0)
try:
self._conn.sendall(pause)
Expand Down Expand Up @@ -710,13 +699,9 @@ def _parse_imdframe(self):
if self._imdsinfo.box:
self._expect_header(IMDHeaderType.IMD_BOX, expected_value=1)
self._read(self._box)
self._imdf.box = parse_box_bytes(
self._box, self._imdsinfo.endianness
)
self._imdf.box = parse_box_bytes(self._box, self._imdsinfo.endianness)
if self._imdsinfo.positions:
self._expect_header(
IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms
)
self._expect_header(IMDHeaderType.IMD_FCOORDS, expected_value=self._n_atoms)
self._read(self._positions)
np.copyto(
self._imdf.positions,
Expand All @@ -736,9 +721,7 @@ def _parse_imdframe(self):
).reshape((self._n_atoms, 3)),
)
if self._imdsinfo.forces:
self._expect_header(
IMDHeaderType.IMD_FORCES, expected_value=self._n_atoms
)
self._expect_header(IMDHeaderType.IMD_FORCES, expected_value=self._n_atoms)
self._read(self._forces)
np.copyto(
self._imdf.forces,
Expand Down Expand Up @@ -802,13 +785,9 @@ def __init__(
# even if they aren't sent every frame. Can be optimized if needed
imdf_memsize = imdframe_memsize(n_atoms, imdsinfo)
self._total_imdf = buffer_size // imdf_memsize
logger.debug(
f"IMDFrameBuffer: Total IMDFrames allocated: {self._total_imdf}"
)
logger.debug(f"IMDFrameBuffer: Total IMDFrames allocated: {self._total_imdf}")
if self._total_imdf == 0:
raise ValueError(
"Buffer size is too small to hold a single IMDFrame"
)
raise ValueError("Buffer size is too small to hold a single IMDFrame")
for i in range(self._total_imdf):
self._empty_q.put(IMDFrame(n_atoms, imdsinfo))

Expand All @@ -820,10 +799,7 @@ def __init__(

def is_full(self):
logger.debug("IMDFrameBuffer: Checking if full")
if (
self._empty_q.qsize() / self._total_imdf
<= self._pause_empty_proportion
):
if self._empty_q.qsize() / self._total_imdf <= self._pause_empty_proportion:
return True

logger.debug(
Expand All @@ -836,8 +812,7 @@ def wait_for_space(self):

# Before acquiring the lock, check if we can return immediately
if (
self._empty_q.qsize() / self._total_imdf
>= self._unpause_empty_proportion
self._empty_q.qsize() / self._total_imdf >= self._unpause_empty_proportion
) and not self._consumer_finished:
return
try:
Expand All @@ -852,9 +827,7 @@ def wait_for_space(self):
)
self._empty_imdf_avail.wait()

logger.debug(
"IMDProducer: Got space in buffer or consumer finished"
)
logger.debug("IMDProducer: Got space in buffer or consumer finished")

if self._consumer_finished:
logger.debug("IMDProducer: Noticing consumer finished")
Expand Down
27 changes: 14 additions & 13 deletions imdclient/tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def set_imdsessioninfo(self, imdsinfo):
@property
def port(self):
"""Get the port the server is bound to.

Returns:
int: The port number, or None if not bound yet.
"""
Expand Down Expand Up @@ -121,15 +121,7 @@ def send_frame(self, i):
endianness = self.imdsinfo.endianness

if self.imdsinfo.time:
time_header = create_header_bytes(IMDHeaderType.IMD_TIME, 1)
time = struct.pack(
f"{endianness}ddQ",
self.traj[i].dt,
self.traj[i].time,
self.traj[i].frame,
)

self.conn.sendall(time_header + time)
self.send_time_packet(i)

if self.imdsinfo.energies:
energy_header = create_header_bytes(IMDHeaderType.IMD_ENERGIES, 1)
Expand Down Expand Up @@ -187,14 +179,23 @@ def send_frame(self, i):

self.conn.sendall(force_header + force)

def send_time_packet(self, i):
time_header = create_header_bytes(IMDHeaderType.IMD_TIME, 1)
time = struct.pack(
f"{self.imdsinfo.endianness}ddQ",
self.traj[i].dt,
self.traj[i].time,
self.traj[i].frame,
)

self.conn.sendall(time_header + time)

def expect_packet(self, packet_type, expected_length=None):
head_buf = bytearray(IMDHEADERSIZE)
read_into_buf(self.conn, head_buf)
header = IMDHeader(head_buf)
if header.type != packet_type:
raise ValueError(
f"Expected {packet_type} packet, got {header.type}"
)
raise ValueError(f"Expected {packet_type} packet, got {header.type}")
if expected_length is not None and header.length != expected_length:
raise ValueError(
f"Expected packet length {expected_length}, got {header.length}"
Expand Down
50 changes: 27 additions & 23 deletions imdclient/tests/test_imdclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

logger = logging.getLogger("imdclient.IMDClient")
file_handler = logging.FileHandler("test.log")
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -62,8 +60,7 @@ def server_client_two_frame_buf(self, universe, imdsinfo):
f"localhost",
server.port,
universe.trajectory.n_atoms,
buffer_size=imdframe_memsize(universe.trajectory.n_atoms, imdsinfo)
* 2,
buffer_size=imdframe_memsize(universe.trajectory.n_atoms, imdsinfo) * 2,
)
server.join_accept_thread()
yield server, client
Expand Down Expand Up @@ -112,9 +109,7 @@ def test_traj_unchanged(self, server_client, universe):
assert_allclose(universe.trajectory[i].positions, imdf.positions)
assert_allclose(universe.trajectory[i].velocities, imdf.velocities)
assert_allclose(universe.trajectory[i].forces, imdf.forces)
assert_allclose(
universe.trajectory[i].triclinic_dimensions, imdf.box
)
assert_allclose(universe.trajectory[i].triclinic_dimensions, imdf.box)

def test_pause_resume_continue(self, server_client_two_frame_buf):
server, client = server_client_two_frame_buf
Expand Down Expand Up @@ -174,23 +169,38 @@ def test_continue_after_disconnect(self, universe, imdsinfo, cont):
continue_after_disconnect=cont,
)
server.join_accept_thread()
server.expect_packet(
IMDHeaderType.IMD_WAIT, expected_length=(int)(not cont)
)
server.expect_packet(IMDHeaderType.IMD_WAIT, expected_length=(int)(not cont))

def test_incorrect_atom_count(self, server_client_incorrect_atoms, universe):
server, client = server_client_incorrect_atoms

server.send_frame(0)

with pytest.raises(EOFError) as exc_info:
client.get_imdframe()

error_msg = str(exc_info.value)
assert f"Expected n_atoms value {universe.atoms.n_atoms + 1}" in error_msg
assert f"got {universe.atoms.n_atoms}" in error_msg
assert "Ensure you are using the correct topology file" in error_msg

def test_partial_frame_in_paused_state(self, server_client_two_frame_buf):
import time

server, client = server_client_two_frame_buf
server.send_frames(0, 2)
# Client's buffer is filled. client should send pause
server.expect_packet(IMDHeaderType.IMD_PAUSE)

server.send_time_packet(2)

time.sleep(6)

client.get_imdframe()
client.get_imdframe()
# i expect this will instantly fail out
client.get_imdframe()


class TestIMDClientV3ContextManager:
@pytest.fixture
Expand Down Expand Up @@ -224,16 +234,10 @@ def test_context_manager_traj_unchanged(self, server, universe):
assert_allclose(universe.trajectory[i].time, imdf.time)
assert_allclose(universe.trajectory[i].dt, imdf.dt)
assert_allclose(universe.trajectory[i].data["step"], imdf.step)
assert_allclose(
universe.trajectory[i].positions, imdf.positions
)
assert_allclose(
universe.trajectory[i].velocities, imdf.velocities
)
assert_allclose(universe.trajectory[i].positions, imdf.positions)
assert_allclose(universe.trajectory[i].velocities, imdf.velocities)
assert_allclose(universe.trajectory[i].forces, imdf.forces)
assert_allclose(
universe.trajectory[i].triclinic_dimensions, imdf.box
)
assert_allclose(universe.trajectory[i].triclinic_dimensions, imdf.box)
i += 1
server.expect_packet(IMDHeaderType.IMD_DISCONNECT)
assert i == 5
Loading