Skip to content

Commit 5496719

Browse files
committed
Make the test actually work
1 parent 75a0f5c commit 5496719

File tree

2 files changed

+35
-24
lines changed

2 files changed

+35
-24
lines changed

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ pydocstyle
1010
pytest>=6.0
1111
sphinx >=3.4
1212
twine
13+
websockets >= 12.0
1314
-e .

tests/test_ws_headers.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import print_function
2+
3+
import asyncio
24
import threading
35
import time
46

5-
from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory
6-
from twisted.internet import reactor
7+
import websockets
78

89
from roslibpy import Ros
910

@@ -12,45 +13,54 @@
1213
'authorization': 'Some auth'
1314
}
1415

15-
class TestWebSocketServerProtocol(WebSocketServerProtocol):
16-
def onConnect(self, request):
17-
for key, value in headers.items():
18-
assert request.headers.get(key) == value, f"Header {key} did not match expected value {value}"
19-
self.factory.context['wait'].set()
2016

21-
def onOpen(self):
22-
self.sendClose()
17+
async def websocket_handler(websocket, path):
18+
request_headers = websocket.request_headers
19+
for key, value in headers.items():
20+
assert request_headers.get(key) == value, f"Header {key} did not match expected value {value}"
21+
await websocket.close()
22+
23+
24+
async def start_server(stop_event):
25+
server = await websockets.serve(websocket_handler, '127.0.0.1', 9000)
26+
await stop_event.wait()
27+
server.close()
28+
await server.wait_closed()
29+
2330

24-
def run_server(context):
25-
factory = WebSocketServerFactory()
26-
factory.protocol = TestWebSocketServerProtocol
27-
factory.context = context
31+
def run_server(stop_event):
32+
asyncio.run(start_server(stop_event))
2833

29-
reactor.listenTCP(9000, factory)
30-
reactor.run(installSignalHandlers=False)
3134

3235
def run_client():
3336
client = Ros('127.0.0.1', 9000, headers=headers)
3437
client.run()
3538
client.close()
3639

40+
3741
def test_websocket_headers():
38-
context = dict(wait=threading.Event())
42+
server_stop_event = asyncio.Event()
43+
stop_event = threading.Event()
3944

40-
server_thread = threading.Thread(target=run_server, args=(context,))
45+
server_thread = threading.Thread(target=run_server, args=(server_stop_event,))
4146
server_thread.start()
4247

4348
time.sleep(1) # Give the server time to start
4449

4550
client_thread = threading.Thread(target=run_client)
4651
client_thread.start()
4752

48-
if not context["wait"].wait(10):
49-
raise Exception("Headers were not as expected")
53+
# Wait for the client thread to finish or timeout after 10 seconds
54+
client_thread.join(timeout=10)
55+
56+
if client_thread.is_alive():
57+
raise Exception("Client did not terminate as expected")
58+
59+
# Signal the server to stop
60+
server_stop_event.set()
61+
server_thread.join(timeout=10)
5062

51-
client_thread.join()
52-
reactor.callFromThread(reactor.stop)
53-
server_thread.join()
63+
if server_thread.is_alive():
64+
raise Exception("Server did not stop as expected")
5465

55-
if __name__ == "__main__":
56-
test_websocket_headers()
66+
stop_event.set()

0 commit comments

Comments
 (0)