1212)
1313from  pydantic  import  StrictStr 
1414from  websockets .client  import  WebSocketClientProtocol , connect 
15+ from  websockets .exceptions  import  (
16+     ConnectionClosedError ,
17+     WebSocketException ,
18+ )
1519
1620from  api .models .error  import  Error 
1721from  api .models .logs  import  Stdout , Stderr 
2731
2832logger  =  logging .getLogger (__name__ )
2933
34+ MAX_RECONNECT_RETRIES  =  3 
35+ PING_TIMEOUT  =  30 
36+ 
3037
3138class  Execution :
3239    def  __init__ (self , in_background : bool  =  False ):
@@ -61,6 +68,15 @@ def __init__(self, context_id: str, session_id: str, language: str, cwd: str):
6168        self ._executions : Dict [str , Execution ] =  {}
6269        self ._lock  =  asyncio .Lock ()
6370
71+     async  def  reconnect (self ):
72+         if  self ._ws  is  not None :
73+             await  self ._ws .close (reason = "Reconnecting" )
74+ 
75+         if  self ._receive_task  is  not None :
76+             await  self ._receive_task 
77+ 
78+         await  self .connect ()
79+ 
6480    async  def  connect (self ):
6581        logger .debug (f"WebSocket connecting to { self .url }  )
6682
@@ -69,6 +85,7 @@ async def connect(self):
6985
7086        self ._ws  =  await  connect (
7187            self .url ,
88+             ping_timeout = PING_TIMEOUT ,
7289            max_size = None ,
7390            max_queue = None ,
7491            logger = ws_logger ,
@@ -274,9 +291,6 @@ async def execute(
274291        env_vars : Dict [StrictStr , str ],
275292        access_token : str ,
276293    ):
277-         message_id  =  str (uuid .uuid4 ())
278-         self ._executions [message_id ] =  Execution ()
279- 
280294        if  self ._ws  is  None :
281295            raise  Exception ("WebSocket not connected" )
282296
@@ -313,13 +327,40 @@ async def execute(
313327                )
314328                complete_code  =  f"{ indented_env_code } \n { complete_code }  
315329
316-             logger .info (
317-                 f"Sending code for the execution ({ message_id } { complete_code }  
318-             )
319-             request  =  self ._get_execute_request (message_id , complete_code , False )
330+             message_id  =  str (uuid .uuid4 ())
331+             execution  =  Execution ()
332+             self ._executions [message_id ] =  execution 
320333
321334            # Send the code for execution 
322-             await  self ._ws .send (request )
335+             # Initial request and retries 
336+             for  i  in  range (1  +  MAX_RECONNECT_RETRIES ):
337+                 try :
338+                     logger .info (
339+                         f"Sending code for the execution ({ message_id } { complete_code }  
340+                     )
341+                     request  =  self ._get_execute_request (
342+                         message_id , complete_code , False 
343+                     )
344+                     await  self ._ws .send (request )
345+                     break 
346+                 except  (ConnectionClosedError , WebSocketException ) as  e :
347+                     # Keep the last result, even if error 
348+                     if  i  <  MAX_RECONNECT_RETRIES :
349+                         logger .warning (
350+                             f"WebSocket connection lost while sending execution request, { i  +  1 } { str (e )}  
351+                         )
352+                         await  self .reconnect ()
353+             else :
354+                 # The retry didn't help, request wasn't sent successfully 
355+                 logger .error ("Failed to send execution request" )
356+                 await  execution .queue .put (
357+                     Error (
358+                         name = "WebSocketError" ,
359+                         value = "Failed to send execution request" ,
360+                         traceback = "" ,
361+                     )
362+                 )
363+                 await  execution .queue .put (UnexpectedEndOfExecution ())
323364
324365            # Stream the results 
325366            async  for  item  in  self ._wait_for_result (message_id ):
@@ -343,6 +384,18 @@ async def _receive_message(self):
343384                await  self ._process_message (json .loads (message ))
344385        except  Exception  as  e :
345386            logger .error (f"WebSocket received error while receiving messages: { str (e )}  )
387+         finally :
388+             # To prevent infinite hang, we need to cancel all ongoing execution as we could lost results during the reconnect 
389+             # Thanks to the locking, there can be either no ongoing execution or just one. 
390+             for  key , execution  in  self ._executions .items ():
391+                 await  execution .queue .put (
392+                     Error (
393+                         name = "WebSocketError" ,
394+                         value = "The connections was lost, rerun the code to get the results" ,
395+                         traceback = "" ,
396+                     )
397+                 )
398+                 await  execution .queue .put (UnexpectedEndOfExecution ())
346399
347400    async  def  _process_message (self , data : dict ):
348401        """ 
0 commit comments