Skip to content

Commit 6e91c40

Browse files
committed
Improve websocket handling
1 parent ceeb5df commit 6e91c40

File tree

2 files changed

+50
-55
lines changed

2 files changed

+50
-55
lines changed

src/homematicip/connection/websocket_handler.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Callable, List
44

55
import aiohttp
6+
from aiohttp import WSMessage
67

78
from homematicip.connection import ATTR_AUTH_TOKEN, ATTR_CLIENT_AUTH, ATTR_ACCESSPOINT_ID
89
from homematicip.connection.connection_context import ConnectionContext
@@ -19,9 +20,8 @@ class WebsocketHandler:
1920
def __init__(self):
2021
self.INITIAL_BACKOFF = 8
2122
self.url = None
22-
self._session = None
23-
self._ws: aiohttp.client.ClientSession = None
2423
self._stop_event = asyncio.Event()
24+
self._websocket_connected = asyncio.Event()
2525
self._reconnect_task = None
2626
self._task_lock = asyncio.Lock()
2727
self._on_message_handlers: List[Callable] = []
@@ -59,27 +59,30 @@ async def _call_handlers(self, handlers, *args):
5959

6060
async def _connect(self, context: ConnectionContext):
6161
backoff = self.INITIAL_BACKOFF
62-
max_backoff = 1800
62+
max_backoff = 900
6363
while not self._stop_event.is_set():
6464
try:
6565
LOGGER.info(f"Connect to {context.websocket_url}")
66-
self._session = aiohttp.ClientSession()
67-
self._ws = await self._session.ws_connect(
68-
context.websocket_url,
69-
headers={
70-
ATTR_AUTH_TOKEN: context.auth_token,
71-
ATTR_CLIENT_AUTH: context.client_auth_token,
72-
ATTR_ACCESSPOINT_ID: context.accesspoint_id
73-
},
74-
ssl=getattr(context, 'ssl_ctx', True),
75-
heartbeat=30,
76-
timeout=aiohttp.ClientTimeout(total=30)
77-
)
78-
LOGGER.info(f"WebSocket connection established to {context.websocket_url}.")
79-
await self._call_handlers(self._on_connected_handler)
80-
backoff = self.INITIAL_BACKOFF
81-
await self._listen()
66+
67+
async with aiohttp.ClientSession() as session:
68+
async with session.ws_connect(
69+
context.websocket_url,
70+
headers={
71+
ATTR_AUTH_TOKEN: context.auth_token,
72+
ATTR_CLIENT_AUTH: context.client_auth_token,
73+
ATTR_ACCESSPOINT_ID: context.accesspoint_id
74+
},
75+
heartbeat=30,
76+
ssl=getattr(context, 'ssl_ctx', True),
77+
) as ws:
78+
backoff = self.INITIAL_BACKOFF
79+
LOGGER.info(f"WebSocket connection established to {context.websocket_url}.")
80+
self._websocket_connected.set()
81+
await self._call_handlers(self._on_connected_handler)
82+
await self._listen(ws)
83+
8284
except Exception as e:
85+
self._websocket_connected.clear()
8386
reason = f"Websocket lost connection: {e}. Retry in {backoff}s."
8487
LOGGER.warning(reason)
8588

@@ -93,24 +96,23 @@ async def _connect(self, context: ConnectionContext):
9396
finally:
9497
await self._cleanup()
9598

96-
async def _listen(self):
97-
async for msg in self._ws:
99+
100+
async def _listen(self, ws):
101+
async for msg in ws:
98102
if msg.type in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
99-
await self._call_handlers(self._on_message_handlers, msg.data)
103+
await self._handle_ws_message(msg)
100104
elif msg.type == aiohttp.WSMsgType.ERROR:
101105
LOGGER.error(f"Error in websocket: {msg}")
102106
break
103107

104-
async def _cleanup(self):
105-
if self._ws:
106-
if not self._ws.closed:
107-
await self._ws.close()
108-
self._ws = None
109-
if self._session:
110-
if not self._session.closed:
111-
await self._session.close()
112-
self._session = None
108+
async def _handle_ws_message(self, message: WSMessage):
109+
try:
110+
await self._call_handlers(self._on_message_handlers, message.data)
111+
except Exception as e:
112+
LOGGER.error(f"Error handling message: {e}", exc_info=True)
113113

114+
async def _cleanup(self):
115+
self._websocket_connected.clear()
114116
await self._call_handlers(self._on_disconnected_handler)
115117

116118
async def start(self, context: ConnectionContext):
@@ -128,14 +130,14 @@ async def stop(self):
128130
LOGGER.info("Stop websocket client...")
129131
self._stop_event.set()
130132
async with self._task_lock:
131-
try:
132-
await self._ws.close()
133-
except Exception as e:
134-
pass
135-
finally:
136-
if self._reconnect_task:
133+
if self._reconnect_task and not self._reconnect_task.done():
134+
self._reconnect_task.cancel()
135+
try:
137136
await self._reconnect_task
138-
self._reconnect_task = None
137+
except asyncio.CancelledError:
138+
pass
139+
140+
self._reconnect_task = None
139141
await self._cleanup()
140142
LOGGER.info("[Stop] WebSocket client stopped.")
141143

@@ -149,4 +151,4 @@ def _handle_task_result(self, task: asyncio.Task):
149151

150152
def is_connected(self):
151153
"""Returns True if the WebSocket connection is active."""
152-
return self._ws is not None and not self._ws.closed
154+
return self._websocket_connected.is_set()

tests/test_websocket.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,9 @@ async def test_is_connected_false_initial():
3030
@pytest.mark.asyncio
3131
async def test_is_connected_true(monkeypatch):
3232
client = WebsocketHandler()
33-
ws_mock = MagicMock()
34-
ws_mock.closed = False
35-
client._ws = ws_mock
33+
client._websocket_connected.set()
3634
assert client.is_connected()
3735

38-
3936
@pytest.mark.asyncio
4037
async def test_handle_task_result_logs_cancelled(caplog):
4138
client = WebsocketHandler()
@@ -58,18 +55,15 @@ async def test_handle_task_result_logs_exception(caplog):
5855

5956
@pytest.mark.asyncio
6057
async def test_cleanup_closes_ws_and_session(monkeypatch):
58+
callback_mock = AsyncMock()
6159
client = WebsocketHandler()
62-
ws_mock = AsyncMock()
63-
session_mock = AsyncMock()
64-
ws_mock.closed = False
65-
session_mock.closed = False
66-
client._ws = ws_mock
67-
client._session = session_mock
60+
client._websocket_connected.set()
61+
client.add_on_disconnected_handler(callback_mock)
62+
6863
await client._cleanup()
69-
ws_mock.close.assert_awaited()
70-
session_mock.close.assert_awaited()
71-
assert client._ws is None
72-
assert client._session is None
64+
65+
assert not client._websocket_connected.is_set()
66+
callback_mock.assert_awaited_once()
7367

7468

7569
@pytest.mark.asyncio
@@ -94,8 +88,7 @@ async def test_listen_calls_handlers(monkeypatch):
9488
DummyMsg('test2', type_=aiohttp.WSMsgType.BINARY),
9589
DummyMsg('err', type_=aiohttp.WSMsgType.ERROR)
9690
]
97-
client._ws = ws_mock
9891
with patch('logging.Logger.debug'), patch('logging.Logger.error'):
99-
await client._listen()
92+
await client._listen(ws_mock)
10093
handler.assert_any_await('test')
10194
handler.assert_any_await('test2')

0 commit comments

Comments
 (0)