22import json
33import logging
44import re
5+ import typing
56from collections import defaultdict
6- from collections .abc import Generator , Iterator
7+ from collections .abc import Generator
78from dataclasses import dataclass , field
89from datetime import UTC , datetime , timedelta
910from enum import Enum , unique
1011from typing import Any , Final
1112
1213import httpx
14+ from playwright ._impl ._sync_base import EventContextManager
1315from playwright .sync_api import FrameLocator , Page , Request
1416from playwright .sync_api import TimeoutError as PlaywrightTimeoutError
1517from playwright .sync_api import WebSocket
1618from pydantic import AnyUrl
17- from pytest_simcore .helpers .logging_tools import log_context
19+
20+ from .logging_tools import log_context
21+
22+ _logger = logging .getLogger (__name__ )
23+
1824
1925SECOND : Final [int ] = 1000
2026MINUTE : Final [int ] = 60 * SECOND
@@ -106,6 +112,94 @@ class SocketIOEvent:
106112SOCKETIO_MESSAGE_PREFIX : Final [str ] = "42"
107113
108114
115+ @dataclass
116+ class RestartableWebSocket :
117+ page : Page
118+ ws : WebSocket
119+ _registered_events : list [tuple [str , typing .Callable | None ]] = field (
120+ default_factory = list
121+ )
122+ _number_of_restarts : int = 0
123+
124+ def __post_init__ (self ):
125+ self ._configure_websocket_events ()
126+
127+ def _configure_websocket_events (self ):
128+ try :
129+ with log_context (
130+ logging .DEBUG ,
131+ msg = "handle websocket message (set to --log-cli-level=DEBUG level if you wanna see all of them)" ,
132+ ) as ctx :
133+
134+ def on_framesent (payload : str | bytes ) -> None :
135+ ctx .logger .debug ("⬇️ Frame sent: %s" , payload )
136+
137+ def on_framereceived (payload : str | bytes ) -> None :
138+ ctx .logger .debug ("⬆️ Frame received: %s" , payload )
139+
140+ def on_close (_ : WebSocket ) -> None :
141+ ctx .logger .warning (
142+ "⚠️ WebSocket closed. Attempting to reconnect..."
143+ )
144+ self ._attempt_reconnect (ctx .logger )
145+
146+ def on_socketerror (error_msg : str ) -> None :
147+ ctx .logger .error ("❌ WebSocket error: %s" , error_msg )
148+
149+ # Attach core event listeners
150+ self .ws .on ("framesent" , on_framesent )
151+ self .ws .on ("framereceived" , on_framereceived )
152+ self .ws .on ("close" , on_close )
153+ self .ws .on ("socketerror" , on_socketerror )
154+
155+ finally :
156+ # Detach core event listeners
157+ self .ws .remove_listener ("framesent" , on_framesent )
158+ self .ws .remove_listener ("framereceived" , on_framereceived )
159+ self .ws .remove_listener ("close" , on_close )
160+ self .ws .remove_listener ("socketerror" , on_socketerror )
161+
162+ def _attempt_reconnect (self , logger : logging .Logger ) -> None :
163+ """
164+ Attempt to reconnect the WebSocket and restore event listeners.
165+ """
166+ try :
167+ with self .page .expect_websocket () as ws_info :
168+ assert not ws_info .value .is_closed ()
169+
170+ self .ws = ws_info .value
171+ self ._number_of_restarts += 1
172+ logger .info (
173+ "🔄 Reconnected to WebSocket successfully. Number of reconnections: %s" ,
174+ self ._number_of_restarts ,
175+ )
176+ self ._configure_websocket_events ()
177+ # Re-register all custom event listeners
178+ for event , predicate in self ._registered_events :
179+ self .ws .expect_event (event , predicate )
180+
181+ except Exception as e : # pylint: disable=broad-except
182+ logger .error ("🚨 Failed to reconnect WebSocket: %s" , e )
183+
184+ def expect_event (
185+ self ,
186+ event : str ,
187+ predicate : typing .Callable | None = None ,
188+ * ,
189+ timeout : float | None = None ,
190+ ) -> EventContextManager :
191+ """
192+ Register an event listener with support for reconnection.
193+ """
194+ output = self .ws .expect_event (event , predicate , timeout = timeout )
195+ self ._registered_events .append ((event , predicate ))
196+ return output
197+
198+ @classmethod
199+ def create (cls , page : Page , ws : WebSocket ):
200+ return cls (page , ws )
201+
202+
109203def decode_socketio_42_message (message : str ) -> SocketIOEvent :
110204 data = json .loads (message .removeprefix (SOCKETIO_MESSAGE_PREFIX ))
111205 return SocketIOEvent (name = data [0 ], obj = data [1 ])
@@ -245,9 +339,14 @@ def __call__(self, message: str) -> bool:
245339 url = f"https://{ self .node_id } .services.{ self .get_partial_product_url ()} "
246340 response = httpx .get (url , timeout = 10 )
247341 self .logger .info (
248- "Querying the service endpoint from the E2E test. Url: %s Response: %s" ,
342+ "Querying the service endpoint from the E2E test. Url: %s Response: %s TIP: %s " ,
249343 url ,
250344 response ,
345+ (
346+ "Response 401 is OK. It means that service is ready."
347+ if response .status_code == 401
348+ else "We are emulating the frontend; a 500 response is acceptable if the service is not yet ready."
349+ ),
251350 )
252351 if response .status_code <= 401 :
253352 # NOTE: If the response status is less than 400, it means that the backend is ready (There are some services that respond with a 3XX)
@@ -278,7 +377,7 @@ def get_partial_product_url(self):
278377def wait_for_pipeline_state (
279378 current_state : RunningState ,
280379 * ,
281- websocket : WebSocket ,
380+ websocket : RestartableWebSocket ,
282381 if_in_states : tuple [RunningState , ...],
283382 expected_states : tuple [RunningState , ...],
284383 timeout_ms : int ,
@@ -301,39 +400,6 @@ def wait_for_pipeline_state(
301400 return current_state
302401
303402
304- @contextlib .contextmanager
305- def web_socket_default_log_handler (web_socket : WebSocket ) -> Iterator [None ]:
306-
307- try :
308- with log_context (
309- logging .DEBUG ,
310- msg = "handle websocket message (set to --log-cli-level=DEBUG level if you wanna see all of them)" ,
311- ) as ctx :
312-
313- def on_framesent (payload : str | bytes ) -> None :
314- ctx .logger .debug ("⬇️ Frame sent: %s" , payload )
315-
316- def on_framereceived (payload : str | bytes ) -> None :
317- ctx .logger .debug ("⬆️ Frame received: %s" , payload )
318-
319- def on_close (payload : WebSocket ) -> None :
320- ctx .logger .warning ("⚠️ Websocket closed: %s" , payload )
321-
322- def on_socketerror (error_msg : str ) -> None :
323- ctx .logger .error ("❌ Websocket error: %s" , error_msg )
324-
325- web_socket .on ("framesent" , on_framesent )
326- web_socket .on ("framereceived" , on_framereceived )
327- web_socket .on ("close" , on_close )
328- web_socket .on ("socketerror" , on_socketerror )
329- yield
330- finally :
331- web_socket .remove_listener ("framesent" , on_framesent )
332- web_socket .remove_listener ("framereceived" , on_framereceived )
333- web_socket .remove_listener ("close" , on_close )
334- web_socket .remove_listener ("socketerror" , on_socketerror )
335-
336-
337403def _node_started_predicate (request : Request ) -> bool :
338404 return bool (
339405 re .search (NODE_START_REQUEST_PATTERN , request .url )
@@ -358,12 +424,14 @@ def expected_service_running(
358424 * ,
359425 page : Page ,
360426 node_id : str ,
361- websocket : WebSocket ,
427+ websocket : RestartableWebSocket ,
362428 timeout : int ,
363429 press_start_button : bool ,
364430 product_url : AnyUrl ,
365431) -> Generator [ServiceRunning , None , None ]:
366- with log_context (logging .INFO , msg = "Waiting for node to run" ) as ctx :
432+ with log_context (
433+ logging .INFO , msg = f"Waiting for node to run. Timeout: { timeout } "
434+ ) as ctx :
367435 waiter = SocketIONodeProgressCompleteWaiter (
368436 node_id = node_id , logger = ctx .logger , product_url = product_url
369437 )
@@ -395,15 +463,17 @@ def wait_for_service_running(
395463 * ,
396464 page : Page ,
397465 node_id : str ,
398- websocket : WebSocket ,
466+ websocket : RestartableWebSocket ,
399467 timeout : int ,
400468 press_start_button : bool ,
401469 product_url : AnyUrl ,
402470) -> FrameLocator :
403471 """NOTE: if the service was already started this will not work as some of the required websocket events will not be emitted again
404472 In which case this will need further adjutment"""
405473
406- with log_context (logging .INFO , msg = "Waiting for node to run" ) as ctx :
474+ with log_context (
475+ logging .INFO , msg = f"Waiting for node to run. Timeout: { timeout } "
476+ ) as ctx :
407477 waiter = SocketIONodeProgressCompleteWaiter (
408478 node_id = node_id , logger = ctx .logger , product_url = product_url
409479 )
0 commit comments