diff --git a/imdclient/tests/test_imdclient.py b/imdclient/tests/test_imdclient.py index 0ad86f7..656d08c 100644 --- a/imdclient/tests/test_imdclient.py +++ b/imdclient/tests/test_imdclient.py @@ -79,7 +79,22 @@ def server_client(self, universe, imdsinfo, request): client = IMDClient( f"localhost", server.port, - universe.trajectory.n_atoms, + universe.atoms.n_atoms, + ) + server.join_accept_thread() + yield server, client + client.stop() + server.cleanup() + + @pytest.fixture + def server_client_incorrect_atoms(self, universe, imdsinfo): + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(imdsinfo) + server.handshake_sequence("localhost", first_frame=False) + client = IMDClient( + f"localhost", + server.port, + universe.atoms.n_atoms + 1, ) server.join_accept_thread() yield server, client @@ -163,6 +178,19 @@ def test_continue_after_disconnect(self, universe, imdsinfo, cont): IMDHeaderType.IMD_WAIT, expected_length=(int)(not cont) ) + def test_incorrect_atom_count(self, server_client_incorrect_atoms, universe): + server, client = server_client_incorrect_atoms + + server.send_frame(0) + + with pytest.raises(EOFError) as exc_info: + client.get_imdframe() + + error_msg = str(exc_info.value) + assert f"Expected n_atoms value {universe.atoms.n_atoms + 1}" in error_msg + assert f"got {universe.atoms.n_atoms}" in error_msg + assert "Ensure you are using the correct topology file" in error_msg + class TestIMDClientV3ContextManager: @pytest.fixture