Skip to content

Commit 3b073e7

Browse files
Copilotljwoods2
andauthored
Fix race condition in InThreadIMDServer by auto-binding to free port (#94)
* fix #81 * Implement fix for race condition in InThreadIMDServer --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ljwoods2 <145226270+ljwoods2@users.noreply.github.com>
1 parent 370b9d3 commit 3b073e7

File tree

2 files changed

+27
-25
lines changed

2 files changed

+27
-25
lines changed

imdclient/tests/server.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,25 @@ def __init__(self, traj):
2929
self.listen_socket = None
3030
self.conn = None
3131
self.accept_thread = None
32+
self._bound_port = None
3233

3334
def set_imdsessioninfo(self, imdsinfo):
3435
self.imdsinfo = imdsinfo
3536

36-
def handshake_sequence(self, host, port, first_frame=True):
37+
@property
38+
def port(self):
39+
"""Get the port the server is bound to.
40+
41+
Returns:
42+
int: The port number, or None if not bound yet.
43+
"""
44+
return self._bound_port
45+
46+
def handshake_sequence(self, host, first_frame=True):
3747
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
38-
s.bind((host, port))
39-
logger.debug(f"InThreadIMDServer: Listening on {host}:{port}")
48+
s.bind((host, 0)) # Bind to port 0 to get a free port
49+
self._bound_port = s.getsockname()[1] # Store the actual bound port
50+
logger.debug(f"InThreadIMDServer: Listening on {host}:{self._bound_port}")
4051
s.listen(60)
4152
self.listen_socket = s
4253

imdclient/tests/test_imdclient.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from imdclient.IMDClient import imdframe_memsize, IMDClient
1616
from imdclient.IMDProtocol import IMDHeaderType
1717
from .utils import (
18-
get_free_port,
1918
create_default_imdsinfo_v3,
2019
)
2120
from .server import InThreadIMDServer
@@ -46,10 +45,6 @@
4645

4746
class TestIMDClientV3:
4847

49-
@pytest.fixture
50-
def port(self):
51-
return get_free_port()
52-
5348
@pytest.fixture
5449
def universe(self):
5550
return mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD)
@@ -59,13 +54,13 @@ def imdsinfo(self):
5954
return create_default_imdsinfo_v3()
6055

6156
@pytest.fixture
62-
def server_client_two_frame_buf(self, universe, imdsinfo, port):
57+
def server_client_two_frame_buf(self, universe, imdsinfo):
6358
server = InThreadIMDServer(universe.trajectory)
6459
server.set_imdsessioninfo(imdsinfo)
65-
server.handshake_sequence("localhost", port, first_frame=False)
60+
server.handshake_sequence("localhost", first_frame=False)
6661
client = IMDClient(
6762
f"localhost",
68-
port,
63+
server.port,
6964
universe.trajectory.n_atoms,
7065
buffer_size=imdframe_memsize(universe.trajectory.n_atoms, imdsinfo)
7166
* 2,
@@ -76,14 +71,14 @@ def server_client_two_frame_buf(self, universe, imdsinfo, port):
7671
server.cleanup()
7772

7873
@pytest.fixture(params=[">", "<"])
79-
def server_client(self, universe, imdsinfo, port, request):
74+
def server_client(self, universe, imdsinfo, request):
8075
server = InThreadIMDServer(universe.trajectory)
8176
imdsinfo.endianness = request.param
8277
server.set_imdsessioninfo(imdsinfo)
83-
server.handshake_sequence("localhost", port, first_frame=False)
78+
server.handshake_sequence("localhost", first_frame=False)
8479
client = IMDClient(
8580
f"localhost",
86-
port,
81+
server.port,
8782
universe.trajectory.n_atoms,
8883
)
8984
server.join_accept_thread()
@@ -153,13 +148,13 @@ def test_pause_resume_no_disconnect(self, server_client_two_frame_buf):
153148
server.expect_packet(IMDHeaderType.IMD_DISCONNECT)
154149

155150
@pytest.mark.parametrize("cont", [True, False])
156-
def test_continue_after_disconnect(self, universe, imdsinfo, port, cont):
151+
def test_continue_after_disconnect(self, universe, imdsinfo, cont):
157152
server = InThreadIMDServer(universe.trajectory)
158153
server.set_imdsessioninfo(imdsinfo)
159-
server.handshake_sequence("localhost", port, first_frame=False)
154+
server.handshake_sequence("localhost", first_frame=False)
160155
client = IMDClient(
161156
f"localhost",
162-
port,
157+
server.port,
163158
universe.trajectory.n_atoms,
164159
continue_after_disconnect=cont,
165160
)
@@ -170,10 +165,6 @@ def test_continue_after_disconnect(self, universe, imdsinfo, port, cont):
170165

171166

172167
class TestIMDClientV3ContextManager:
173-
@pytest.fixture
174-
def port(self):
175-
return get_free_port()
176-
177168
@pytest.fixture
178169
def universe(self):
179170
return mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD)
@@ -183,19 +174,19 @@ def imdsinfo(self):
183174
return create_default_imdsinfo_v3()
184175

185176
@pytest.fixture
186-
def server(self, universe, imdsinfo, port):
177+
def server(self, universe, imdsinfo):
187178
server = InThreadIMDServer(universe.trajectory)
188179
server.set_imdsessioninfo(imdsinfo)
189180
yield server
190181
server.cleanup()
191182

192-
def test_context_manager_traj_unchanged(self, server, port, universe):
193-
server.handshake_sequence("localhost", port, first_frame=False)
183+
def test_context_manager_traj_unchanged(self, server, universe):
184+
server.handshake_sequence("localhost", first_frame=False)
194185

195186
i = 0
196187
with IMDClient(
197188
"localhost",
198-
port,
189+
server.port,
199190
universe.trajectory.n_atoms,
200191
) as client:
201192
server.send_frames(0, 5)

0 commit comments

Comments
 (0)