33from typing import Callable , List
44
55import aiohttp
6+ from aiohttp import WSMessage
67
78from homematicip .connection import ATTR_AUTH_TOKEN , ATTR_CLIENT_AUTH , ATTR_ACCESSPOINT_ID
89from homematicip .connection .connection_context import ConnectionContext
@@ -19,9 +20,8 @@ class WebsocketHandler:
1920 def __init__ (self ):
2021 self .INITIAL_BACKOFF = 8
2122 self .url = None
22- self ._session = None
23- self ._ws : aiohttp .client .ClientSession = None
2423 self ._stop_event = asyncio .Event ()
24+ self ._websocket_connected = asyncio .Event ()
2525 self ._reconnect_task = None
2626 self ._task_lock = asyncio .Lock ()
2727 self ._on_message_handlers : List [Callable ] = []
@@ -59,27 +59,30 @@ async def _call_handlers(self, handlers, *args):
5959
6060 async def _connect (self , context : ConnectionContext ):
6161 backoff = self .INITIAL_BACKOFF
62- max_backoff = 1800
62+ max_backoff = 900
6363 while not self ._stop_event .is_set ():
6464 try :
6565 LOGGER .info (f"Connect to { context .websocket_url } " )
66- self ._session = aiohttp .ClientSession ()
67- self ._ws = await self ._session .ws_connect (
68- context .websocket_url ,
69- headers = {
70- ATTR_AUTH_TOKEN : context .auth_token ,
71- ATTR_CLIENT_AUTH : context .client_auth_token ,
72- ATTR_ACCESSPOINT_ID : context .accesspoint_id
73- },
74- ssl = getattr (context , 'ssl_ctx' , True ),
75- heartbeat = 30 ,
76- timeout = aiohttp .ClientTimeout (total = 30 )
77- )
78- LOGGER .info (f"WebSocket connection established to { context .websocket_url } ." )
79- await self ._call_handlers (self ._on_connected_handler )
80- backoff = self .INITIAL_BACKOFF
81- await self ._listen ()
66+
67+ async with aiohttp .ClientSession () as session :
68+ async with session .ws_connect (
69+ context .websocket_url ,
70+ headers = {
71+ ATTR_AUTH_TOKEN : context .auth_token ,
72+ ATTR_CLIENT_AUTH : context .client_auth_token ,
73+ ATTR_ACCESSPOINT_ID : context .accesspoint_id
74+ },
75+ heartbeat = 30 ,
76+ ssl = getattr (context , 'ssl_ctx' , True ),
77+ ) as ws :
78+ backoff = self .INITIAL_BACKOFF
79+ LOGGER .info (f"WebSocket connection established to { context .websocket_url } ." )
80+ self ._websocket_connected .set ()
81+ await self ._call_handlers (self ._on_connected_handler )
82+ await self ._listen (ws )
83+
8284 except Exception as e :
85+ self ._websocket_connected .clear ()
8386 reason = f"Websocket lost connection: { e } . Retry in { backoff } s."
8487 LOGGER .warning (reason )
8588
@@ -93,24 +96,23 @@ async def _connect(self, context: ConnectionContext):
9396 finally :
9497 await self ._cleanup ()
9598
96- async def _listen (self ):
97- async for msg in self ._ws :
99+
100+ async def _listen (self , ws ):
101+ async for msg in ws :
98102 if msg .type in (aiohttp .WSMsgType .TEXT , aiohttp .WSMsgType .BINARY ):
99- await self ._call_handlers ( self . _on_message_handlers , msg . data )
103+ await self ._handle_ws_message ( msg )
100104 elif msg .type == aiohttp .WSMsgType .ERROR :
101105 LOGGER .error (f"Error in websocket: { msg } " )
102106 break
103107
104- async def _cleanup (self ):
105- if self ._ws :
106- if not self ._ws .closed :
107- await self ._ws .close ()
108- self ._ws = None
109- if self ._session :
110- if not self ._session .closed :
111- await self ._session .close ()
112- self ._session = None
108+ async def _handle_ws_message (self , message : WSMessage ):
109+ try :
110+ await self ._call_handlers (self ._on_message_handlers , message .data )
111+ except Exception as e :
112+ LOGGER .error (f"Error handling message: { e } " , exc_info = True )
113113
114+ async def _cleanup (self ):
115+ self ._websocket_connected .clear ()
114116 await self ._call_handlers (self ._on_disconnected_handler )
115117
116118 async def start (self , context : ConnectionContext ):
@@ -128,14 +130,14 @@ async def stop(self):
128130 LOGGER .info ("Stop websocket client..." )
129131 self ._stop_event .set ()
130132 async with self ._task_lock :
131- try :
132- await self ._ws .close ()
133- except Exception as e :
134- pass
135- finally :
136- if self ._reconnect_task :
133+ if self ._reconnect_task and not self ._reconnect_task .done ():
134+ self ._reconnect_task .cancel ()
135+ try :
137136 await self ._reconnect_task
138- self ._reconnect_task = None
137+ except asyncio .CancelledError :
138+ pass
139+
140+ self ._reconnect_task = None
139141 await self ._cleanup ()
140142 LOGGER .info ("[Stop] WebSocket client stopped." )
141143
@@ -149,4 +151,4 @@ def _handle_task_result(self, task: asyncio.Task):
149151
150152 def is_connected (self ):
151153 """Returns True if the WebSocket connection is active."""
152- return self ._ws is not None and not self . _ws . closed
154+ return self ._websocket_connected . is_set ()
0 commit comments