@@ -80,6 +80,62 @@ async def connect(self):
8080 name = "receive_message" ,
8181 )
8282
async def reconnect(self, max_retries: int = 5, retry_delay: float = 0.1) -> bool:
    """Tear down the current WebSocket and try to establish a new one.

    The existing socket (if any) is closed, the receive task (if still
    running) is cancelled and awaited, both references are cleared, and
    ``self.connect()`` is then attempted up to ``max_retries`` times with a
    fixed ``retry_delay`` pause between failed attempts.

    Args:
        max_retries: Maximum number of ``connect()`` attempts. With zero or
            a negative value no attempt is made.
        retry_delay: Seconds to sleep between failed attempts.

    Returns:
        ``True`` as soon as ``connect()`` succeeds, ``False`` if every
        attempt failed or no attempt was made.

    Raises:
        asyncio.CancelledError: If the caller of ``reconnect()`` is itself
            cancelled while waiting; the cancellation is propagated rather
            than swallowed.
    """
    logger.info(f"Attempting to reconnect WebSocket {self.context_id}")

    # Close existing connection if any. A failure while closing must not
    # prevent the reconnect, so it is only logged.
    if self._ws is not None:
        try:
            await self._ws.close()
        except Exception as e:
            logger.warning(f"Error closing existing WebSocket: {e}")

    # Cancel existing receive task if any and wait for it to finish.
    if self._receive_task is not None and not self._receive_task.done():
        self._receive_task.cancel()
        # gather(return_exceptions=True) hands the old task's outcome back
        # as a value. Unlike ``try: await task / except CancelledError``,
        # this does not swallow a cancellation aimed at *this* coroutine,
        # and a stale failure of the old task cannot abort the reconnect.
        (outcome,) = await asyncio.gather(
            self._receive_task, return_exceptions=True
        )
        if isinstance(outcome, Exception):
            logger.warning(f"Previous receive task ended with error: {outcome}")

    # Reset WebSocket and task references
    self._ws = None
    self._receive_task = None

    # Attempt to reconnect with a fixed delay between attempts
    for attempt in range(max_retries):
        try:
            await self.connect()
        except Exception as e:
            if attempt < max_retries - 1:
                logger.warning(
                    f"Reconnection attempt {attempt + 1} failed: {e}. "
                    f"Retrying in {retry_delay}s..."
                )
                await asyncio.sleep(retry_delay)
            else:
                logger.error(
                    f"Failed to reconnect WebSocket {self.context_id} "
                    f"after {max_retries} attempts: {e}"
                )
                return False
        else:
            logger.info(
                f"Successfully reconnected WebSocket {self.context_id} "
                f"on attempt {attempt + 1}"
            )
            return True

    # Only reached when max_retries <= 0, i.e. the loop body never ran.
    return False
121+
def is_connected(self) -> bool:
    """Report whether both the socket and its receive task are usable.

    Returns:
        ``True`` only when ``self._ws`` is set and not closed, and
        ``self._receive_task`` is set and has not finished; ``False``
        otherwise.
    """
    socket = self._ws
    if socket is None or socket.closed:
        return False

    receiver = self._receive_task
    if receiver is None:
        return False

    # A finished receive task means nothing is reading from the socket.
    return not receiver.done()
130+
async def ensure_connected(self):
    """Make sure the WebSocket is usable, reconnecting when it is not.

    Does nothing when ``is_connected()`` already reports ``True``.
    Otherwise calls ``reconnect()`` with its default arguments.

    Raises:
        Exception: If ``reconnect()`` reports failure.
    """
    if self.is_connected():
        return

    logger.warning(f"WebSocket {self.context_id} is not connected, attempting to reconnect")
    reconnected = await self.reconnect()
    if reconnected:
        return

    raise Exception(f"Failed to reconnect WebSocket {self.context_id}")
138+
83139 def _get_execute_request (
84140 self , msg_id : str , code : Union [str , StrictStr ], background : bool
85141 ) -> str :
@@ -209,11 +265,15 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]):
209265 cleanup_code = self ._reset_env_vars_code (env_vars )
210266 if cleanup_code :
211267 logger .info (f"Cleaning up env vars: { cleanup_code } " )
268+ # Ensure WebSocket is connected before sending cleanup request
269+ await self .ensure_connected ()
212270 request = self ._get_execute_request (message_id , cleanup_code , True )
271+ if self ._ws is None :
272+ raise Exception ("WebSocket not connected" )
213273 await self ._ws .send (request )
214274
215275 async for item in self ._wait_for_result (message_id ):
216- if item [ "type" ] == "error" :
276+ if isinstance ( item , dict ) and item . get ( "type" ) == "error" :
217277 logger .error (f"Error during env var cleanup: { item } " )
218278 finally :
219279 del self ._executions [message_id ]
@@ -242,6 +302,10 @@ async def change_current_directory(
242302 ):
243303 message_id = str (uuid .uuid4 ())
244304 self ._executions [message_id ] = Execution (in_background = True )
305+
306+ # Ensure WebSocket is connected before changing directory
307+ await self .ensure_connected ()
308+
245309 if language == "python" :
246310 request = self ._get_execute_request (message_id , f"%cd { path } " , True )
247311 elif language == "deno" :
@@ -262,10 +326,13 @@ async def change_current_directory(
262326 else :
263327 return
264328
329+ if self ._ws is None :
330+ raise Exception ("WebSocket not connected" )
331+
265332 await self ._ws .send (request )
266333
267334 async for item in self ._wait_for_result (message_id ):
268- if item [ "type" ] == "error" :
335+ if isinstance ( item , dict ) and item . get ( "type" ) == "error" :
269336 raise ExecutionError (f"Error during execution: { item } " )
270337
271338 async def execute (
@@ -277,8 +344,8 @@ async def execute(
277344 message_id = str (uuid .uuid4 ())
278345 self ._executions [message_id ] = Execution ()
279346
280- if self . _ws is None :
281- raise Exception ( "WebSocket not connected" )
347+ # Ensure WebSocket is connected before executing
348+ await self . ensure_connected ( )
282349
283350 async with self ._lock :
284351 # Wait for any pending cleanup task to complete
@@ -319,6 +386,8 @@ async def execute(
319386 request = self ._get_execute_request (message_id , complete_code , False )
320387
321388 # Send the code for execution
389+ if self ._ws is None :
390+ raise Exception ("WebSocket not connected" )
322391 await self ._ws .send (request )
323392
324393 # Stream the results
@@ -343,6 +412,16 @@ async def _receive_message(self):
343412 await self ._process_message (json .loads (message ))
344413 except Exception as e :
345414 logger .error (f"WebSocket received error while receiving messages: { str (e )} " )
415+ # Mark all pending executions as failed due to connection loss
416+ for execution in self ._executions .values ():
417+ await execution .queue .put (
418+ Error (
419+ name = "ConnectionLost" ,
420+ value = "WebSocket connection was lost during execution" ,
421+ traceback = "" ,
422+ )
423+ )
424+ await execution .queue .put (UnexpectedEndOfExecution ())
346425
347426 async def _process_message (self , data : dict ):
348427 """
0 commit comments