@@ -80,6 +80,62 @@ async def connect(self):
8080            name = "receive_message" ,
8181        )
8282
83+     async  def  reconnect (self , max_retries : int  =  5 , retry_delay : float  =  0.1 ):
84+         """Reconnect the WebSocket if it's disconnected with retry logic.""" 
85+         logger .info (f"Attempting to reconnect WebSocket { self .context_id }  )
86+         
87+         # Close existing connection if any 
88+         if  self ._ws  is  not None :
89+             try :
90+                 await  self ._ws .close ()
91+             except  Exception  as  e :
92+                 logger .warning (f"Error closing existing WebSocket: { e }  )
93+         
94+         # Cancel existing receive task if any 
95+         if  self ._receive_task  is  not None  and  not  self ._receive_task .done ():
96+             self ._receive_task .cancel ()
97+             try :
98+                 await  self ._receive_task 
99+             except  asyncio .CancelledError :
100+                 pass 
101+         
102+         # Reset WebSocket and task references 
103+         self ._ws  =  None 
104+         self ._receive_task  =  None 
105+         
106+         # Attempt to reconnect with fixed delay 
107+         for  attempt  in  range (max_retries ):
108+             try :
109+                 await  self .connect ()
110+                 logger .info (f"Successfully reconnected WebSocket { self .context_id } { attempt  +  1 }  )
111+                 return  True 
112+             except  Exception  as  e :
113+                 if  attempt  <  max_retries  -  1 :
114+                     logger .warning (f"Reconnection attempt { attempt  +  1 } { e } { retry_delay }  )
115+                     await  asyncio .sleep (retry_delay )
116+                 else :
117+                     logger .error (f"Failed to reconnect WebSocket { self .context_id } { max_retries } { e }  )
118+                     return  False 
119+         
120+         return  False 
121+ 
122+     def  is_connected (self ) ->  bool :
123+         """Check if the WebSocket is connected and healthy.""" 
124+         return  (
125+             self ._ws  is  not None  
126+             and  not  self ._ws .closed  
127+             and  self ._receive_task  is  not None  
128+             and  not  self ._receive_task .done ()
129+         )
130+ 
131+     async  def  ensure_connected (self ):
132+         """Ensure WebSocket is connected, reconnect if necessary.""" 
133+         if  not  self .is_connected ():
134+             logger .warning (f"WebSocket { self .context_id }  )
135+             success  =  await  self .reconnect ()
136+             if  not  success :
137+                 raise  Exception (f"Failed to reconnect WebSocket { self .context_id }  )
138+ 
83139    def  _get_execute_request (
84140        self , msg_id : str , code : Union [str , StrictStr ], background : bool 
85141    ) ->  str :
@@ -209,11 +265,15 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]):
209265            cleanup_code  =  self ._reset_env_vars_code (env_vars )
210266            if  cleanup_code :
211267                logger .info (f"Cleaning up env vars: { cleanup_code }  )
268+                 # Ensure WebSocket is connected before sending cleanup request 
269+                 await  self .ensure_connected ()
212270                request  =  self ._get_execute_request (message_id , cleanup_code , True )
271+                 if  self ._ws  is  None :
272+                     raise  Exception ("WebSocket not connected" )
213273                await  self ._ws .send (request )
214274
215275                async  for  item  in  self ._wait_for_result (message_id ):
216-                     if  item [ "type" ]  ==  "error" :
276+                     if  isinstance ( item ,  dict )  and   item . get ( "type" )  ==  "error" :
217277                        logger .error (f"Error during env var cleanup: { item }  )
218278        finally :
219279            del  self ._executions [message_id ]
@@ -242,6 +302,10 @@ async def change_current_directory(
242302    ):
243303        message_id  =  str (uuid .uuid4 ())
244304        self ._executions [message_id ] =  Execution (in_background = True )
305+         
306+         # Ensure WebSocket is connected before changing directory 
307+         await  self .ensure_connected ()
308+         
245309        if  language  ==  "python" :
246310            request  =  self ._get_execute_request (message_id , f"%cd { path }  , True )
247311        elif  language  ==  "deno" :
@@ -262,10 +326,13 @@ async def change_current_directory(
262326        else :
263327            return 
264328
329+         if  self ._ws  is  None :
330+             raise  Exception ("WebSocket not connected" )
331+         
265332        await  self ._ws .send (request )
266333
267334        async  for  item  in  self ._wait_for_result (message_id ):
268-             if  item [ "type" ]  ==  "error" :
335+             if  isinstance ( item ,  dict )  and   item . get ( "type" )  ==  "error" :
269336                raise  ExecutionError (f"Error during execution: { item }  )
270337
271338    async  def  execute (
@@ -277,8 +344,8 @@ async def execute(
277344        message_id  =  str (uuid .uuid4 ())
278345        self ._executions [message_id ] =  Execution ()
279346
280-         if   self . _ws   is  None : 
281-              raise   Exception ( "WebSocket not connected" )
347+         # Ensure WebSocket  is connected before executing 
348+         await   self . ensure_connected ( )
282349
283350        async  with  self ._lock :
284351            # Wait for any pending cleanup task to complete 
@@ -319,6 +386,8 @@ async def execute(
319386            request  =  self ._get_execute_request (message_id , complete_code , False )
320387
321388            # Send the code for execution 
389+             if  self ._ws  is  None :
390+                 raise  Exception ("WebSocket not connected" )
322391            await  self ._ws .send (request )
323392
324393            # Stream the results 
@@ -343,6 +412,16 @@ async def _receive_message(self):
343412                await  self ._process_message (json .loads (message ))
344413        except  Exception  as  e :
345414            logger .error (f"WebSocket received error while receiving messages: { str (e )}  )
415+             # Mark all pending executions as failed due to connection loss 
416+             for  execution  in  self ._executions .values ():
417+                 await  execution .queue .put (
418+                     Error (
419+                         name = "ConnectionLost" ,
420+                         value = "WebSocket connection was lost during execution" ,
421+                         traceback = "" ,
422+                     )
423+                 )
424+                 await  execution .queue .put (UnexpectedEndOfExecution ())
346425
347426    async  def  _process_message (self , data : dict ):
348427        """ 
0 commit comments