22import json
33import logging
44import re
5+ import typing
56from collections import defaultdict
6- from collections .abc import Generator , Iterator
7+ from collections .abc import Generator
78from dataclasses import dataclass , field
89from datetime import UTC , datetime , timedelta
910from enum import Enum , unique
1011from typing import Any , Final
1112
1213import httpx
14+ from playwright ._impl ._sync_base import EventContextManager
1315from playwright .sync_api import FrameLocator , Page , Request
1416from playwright .sync_api import TimeoutError as PlaywrightTimeoutError
1517from playwright .sync_api import WebSocket
1618from pydantic import AnyUrl
17- from pytest_simcore .helpers .logging_tools import log_context
19+
20+ from .logging_tools import log_context
21+
22+ _logger = logging .getLogger (__name__ )
23+
1824
1925SECOND : Final [int ] = 1000
2026MINUTE : Final [int ] = 60 * SECOND
@@ -106,6 +112,94 @@ class SocketIOEvent:
106112SOCKETIO_MESSAGE_PREFIX : Final [str ] = "42"
107113
108114
@dataclass
class RestartableWebSocket:
    """Playwright ``WebSocket`` wrapper that swaps in a new socket on close.

    Core listeners log every frame sent/received and every socket error.
    When the wrapped socket closes, the wrapper waits for ``page`` to open
    a new websocket, stores it in ``ws`` and attaches the listeners again.
    """

    page: Page
    ws: WebSocket
    # (event, predicate) pairs recorded by ``expect_event``
    _registered_events: list[tuple[str, typing.Callable | None]] = field(
        default_factory=list
    )
    # how many times ``ws`` has been replaced after a close
    _number_of_restarts: int = 0

    def __post_init__(self) -> None:
        self._configure_websocket_events()

    def _configure_websocket_events(self) -> None:
        """Attach the logging/reconnect listeners to the current ``self.ws``.

        The listeners are left attached on purpose: they have to outlive
        this call to be of any use, and the ``close`` listener is what
        triggers the reconnection. (Detaching them right after attaching,
        as a ``finally`` clause used to do here, disabled all of them.)
        """
        with log_context(
            logging.DEBUG,
            msg="handle websocket message (set to --log-cli-level=DEBUG level if you wanna see all of them)",
        ) as ctx:

            def on_framesent(payload: str | bytes) -> None:
                ctx.logger.debug("⬇️ Frame sent: %s", payload)

            def on_framereceived(payload: str | bytes) -> None:
                ctx.logger.debug("⬆️ Frame received: %s", payload)

            def on_close(_: WebSocket) -> None:
                ctx.logger.warning(
                    "⚠️ WebSocket closed. Attempting to reconnect..."
                )
                self._attempt_reconnect(ctx.logger)

            def on_socketerror(error_msg: str) -> None:
                ctx.logger.error("❌ WebSocket error: %s", error_msg)

            # Attach core event listeners
            self.ws.on("framesent", on_framesent)
            self.ws.on("framereceived", on_framereceived)
            self.ws.on("close", on_close)
            self.ws.on("socketerror", on_socketerror)

    def _attempt_reconnect(self, logger: logging.Logger) -> None:
        """
        Attempt to reconnect the WebSocket and restore event listeners.

        Any failure is logged through ``logger`` and not re-raised.
        """
        try:
            # Nothing to trigger here: we only wait for the page to open
            # a new websocket, the event is awaited when the block exits.
            with self.page.expect_websocket() as ws_info:
                pass

            new_ws = ws_info.value
            if new_ws.is_closed():
                msg = "the new websocket is already closed"
                raise RuntimeError(msg)

            self.ws = new_ws
            self._number_of_restarts += 1
            logger.info(
                "🔄 Reconnected to WebSocket successfully. Number of reconnections: %s",
                self._number_of_restarts,
            )
            self._configure_websocket_events()
            # Re-register all custom event listeners
            # NOTE(review): the context managers created here are discarded,
            # so callers of `expect_event` keep waiting on the old socket --
            # confirm this is the intended behaviour.
            for event, predicate in self._registered_events:
                self.ws.expect_event(event, predicate)

        except Exception as e:  # pylint: disable=broad-except
            logger.error("🚨 Failed to reconnect WebSocket: %s", e)

    def expect_event(
        self,
        event: str,
        predicate: typing.Callable | None = None,
        *,
        timeout: float | None = None,
    ) -> EventContextManager:
        """
        Register an event listener with support for reconnection.

        Delegates to ``self.ws.expect_event`` and records the
        ``(event, predicate)`` pair so it can be replayed after a reconnect.
        """
        output = self.ws.expect_event(event, predicate, timeout=timeout)
        self._registered_events.append((event, predicate))
        return output

    @classmethod
    def create(cls, page: Page, ws: WebSocket) -> "RestartableWebSocket":
        """Build a wrapper around ``ws`` for the given ``page``."""
        return cls(page, ws)
201+
202+
def decode_socketio_42_message(message: str) -> SocketIOEvent:
    """Turn a raw socket.io frame into a ``SocketIOEvent``.

    ``SOCKETIO_MESSAGE_PREFIX`` is stripped from the start of ``message``
    (when present) and the remainder is parsed as JSON. The first element
    of the parsed value becomes the event name, the second its payload.
    """
    json_part = message.removeprefix(SOCKETIO_MESSAGE_PREFIX)
    decoded = json.loads(json_part)
    return SocketIOEvent(name=decoded[0], obj=decoded[1])
@@ -278,7 +372,7 @@ def get_partial_product_url(self):
278372def wait_for_pipeline_state (
279373 current_state : RunningState ,
280374 * ,
281- websocket : WebSocket ,
375+ websocket : RestartableWebSocket ,
282376 if_in_states : tuple [RunningState , ...],
283377 expected_states : tuple [RunningState , ...],
284378 timeout_ms : int ,
@@ -301,39 +395,6 @@ def wait_for_pipeline_state(
301395 return current_state
302396
303397
@contextlib.contextmanager
def web_socket_default_log_handler(web_socket: WebSocket) -> Iterator[None]:
    """Log what happens on ``web_socket`` for as long as the context is open.

    Four listeners (frame sent, frame received, close, socket error) are
    attached on entry and removed again on exit, whether or not the body
    raised.
    """
    try:
        with log_context(
            logging.DEBUG,
            msg="handle websocket message (set to --log-cli-level=DEBUG level if you wanna see all of them)",
        ) as ctx:

            def _log_sent(payload: str | bytes) -> None:
                ctx.logger.debug("⬇️ Frame sent: %s", payload)

            def _log_received(payload: str | bytes) -> None:
                ctx.logger.debug("⬆️ Frame received: %s", payload)

            def _log_closed(payload: WebSocket) -> None:
                ctx.logger.warning("⚠️ Websocket closed: %s", payload)

            def _log_error(error_msg: str) -> None:
                ctx.logger.error("❌ Websocket error: %s", error_msg)

            # event name -> listener, in attach/detach order
            listeners = (
                ("framesent", _log_sent),
                ("framereceived", _log_received),
                ("close", _log_closed),
                ("socketerror", _log_error),
            )
            for event_name, listener in listeners:
                web_socket.on(event_name, listener)
            yield
    finally:
        for event_name, listener in listeners:
            web_socket.remove_listener(event_name, listener)
335-
336-
337398def _node_started_predicate (request : Request ) -> bool :
338399 return bool (
339400 re .search (NODE_START_REQUEST_PATTERN , request .url )
@@ -358,12 +419,14 @@ def expected_service_running(
358419 * ,
359420 page : Page ,
360421 node_id : str ,
361- websocket : WebSocket ,
422+ websocket : RestartableWebSocket ,
362423 timeout : int ,
363424 press_start_button : bool ,
364425 product_url : AnyUrl ,
365426) -> Generator [ServiceRunning , None , None ]:
366- with log_context (logging .INFO , msg = "Waiting for node to run" ) as ctx :
427+ with log_context (
428+ logging .INFO , msg = f"Waiting for node to run. Timeout: { timeout } "
429+ ) as ctx :
367430 waiter = SocketIONodeProgressCompleteWaiter (
368431 node_id = node_id , logger = ctx .logger , product_url = product_url
369432 )
@@ -395,15 +458,17 @@ def wait_for_service_running(
395458 * ,
396459 page : Page ,
397460 node_id : str ,
398- websocket : WebSocket ,
461+ websocket : RestartableWebSocket ,
399462 timeout : int ,
400463 press_start_button : bool ,
401464 product_url : AnyUrl ,
402465) -> FrameLocator :
403466 """NOTE: if the service was already started this will not work as some of the required websocket events will not be emitted again
404467 In which case this will need further adjutment"""
405468
406- with log_context (logging .INFO , msg = "Waiting for node to run" ) as ctx :
469+ with log_context (
470+ logging .INFO , msg = f"Waiting for node to run. Timeout: { timeout } "
471+ ) as ctx :
407472 waiter = SocketIONodeProgressCompleteWaiter (
408473 node_id = node_id , logger = ctx .logger , product_url = product_url
409474 )
0 commit comments