Skip to content

Commit d36dd0e

Browse files
authored
Merge pull request #124 from TheSpeedM/main
Add websocket header support to the ROS-client
2 parents 055d507 + e005dfa commit d36dd0e

File tree

4 files changed

+72
-2
lines changed

4 files changed

+72
-2
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Unreleased
1212

1313
**Added**
1414

15+
* Added websocket header support to the ROS-client.
16+
1517
**Changed**
1618

1719
**Fixed**

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ pydocstyle
1010
pytest>=6.0
1111
sphinx >=3.4
1212
twine
13+
websockets >= 12.0
1314
-e .

src/roslibpy/ros.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@ class Ros(object):
3232
host (:obj:`str`): Name or IP address of the ROS bridge host, e.g. ``127.0.0.1``.
3333
port (:obj:`int`): ROS bridge port, e.g. ``9090``.
3434
is_secure (:obj:`bool`): ``True`` to use a secure web sockets connection, otherwise ``False``.
35+
headers (:obj:`dict`): Additional headers to include in the WebSocket connection.
3536
"""
3637

37-
def __init__(self, host, port=None, is_secure=False):
38+
def __init__(self, host, port=None, is_secure=False, headers=None):
3839
self._id_counter = 0
3940
url = RosBridgeClientFactory.create_url(host, port, is_secure)
40-
self.factory = RosBridgeClientFactory(url)
41+
self.factory = RosBridgeClientFactory(url, headers=headers)
4142
self.is_connecting = False
4243
self.connect()
4344

tests/test_ws_headers.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import print_function
2+
3+
import asyncio
4+
import threading
5+
import time
6+
7+
import websockets
8+
9+
from roslibpy import Ros
10+
11+
headers = {
12+
'cookie': 'token=rosbridge',
13+
'authorization': 'Some auth'
14+
}
15+
16+
17+
async def websocket_handler(websocket, path):
18+
request_headers = websocket.request_headers
19+
for key, value in headers.items():
20+
assert request_headers.get(key) == value, f"Header {key} did not match expected value {value}"
21+
await websocket.close()
22+
23+
24+
async def start_server(stop_event):
25+
server = await websockets.serve(websocket_handler, '127.0.0.1', 9000)
26+
await stop_event.wait()
27+
server.close()
28+
await server.wait_closed()
29+
30+
31+
def run_server(stop_event):
32+
asyncio.run(start_server(stop_event))
33+
34+
35+
def run_client():
36+
client = Ros('127.0.0.1', 9000, headers=headers)
37+
client.run()
38+
client.close()
39+
40+
41+
def test_websocket_headers():
42+
server_stop_event = asyncio.Event()
43+
stop_event = threading.Event()
44+
45+
server_thread = threading.Thread(target=run_server, args=(server_stop_event,))
46+
server_thread.start()
47+
48+
time.sleep(1) # Give the server time to start
49+
50+
client_thread = threading.Thread(target=run_client)
51+
client_thread.start()
52+
53+
# Wait for the client thread to finish or timeout after 10 seconds
54+
client_thread.join(timeout=10)
55+
56+
if client_thread.is_alive():
57+
raise Exception("Client did not terminate as expected")
58+
59+
# Signal the server to stop
60+
server_stop_event.set()
61+
server_thread.join(timeout=10)
62+
63+
if server_thread.is_alive():
64+
raise Exception("Server did not stop as expected")
65+
66+
stop_event.set()

0 commit comments

Comments
 (0)