Skip to content

Commit 44c41bc

Browse files
committed
pre-flight check to ensure web socket is connected and reconnect
1 parent 09efac2 commit 44c41bc

File tree

1 file changed

+83
-4
lines changed

1 file changed

+83
-4
lines changed

template/server/messaging.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,62 @@ async def connect(self):
8080
name="receive_message",
8181
)
8282

83+
async def reconnect(self, max_retries: int = 5, retry_delay: float = 0.1):
    """Tear down the current WebSocket session and try to open a new one.

    The existing socket (if any) is closed, a still-running receive task is
    cancelled and awaited, both references are cleared, and ``connect()`` is
    then retried with a fixed pause between failed attempts.

    Args:
        max_retries: Maximum number of ``connect()`` attempts.
        retry_delay: Seconds to sleep after a failed attempt before retrying.

    Returns:
        True as soon as one ``connect()`` call succeeds, False when every
        attempt failed (or no attempt was made because ``max_retries`` < 1).
    """
    logger.info(f"Attempting to reconnect WebSocket {self.context_id}")

    # Best-effort shutdown of the old socket: a failing close() must not
    # prevent the reconnect, so it is only logged.
    stale_ws = self._ws
    if stale_ws is not None:
        try:
            await stale_ws.close()
        except Exception as close_err:
            logger.warning(f"Error closing existing WebSocket: {close_err}")

    # Stop the old receive loop and wait for the cancellation to land.
    stale_task = self._receive_task
    if stale_task is not None and not stale_task.done():
        stale_task.cancel()
        try:
            await stale_task
        except asyncio.CancelledError:
            pass

    # Drop both references before connect() is called again.
    self._ws = None
    self._receive_task = None

    # Fixed-delay retry loop; attempt_no is 1-based to match the log text.
    for attempt_no in range(1, max_retries + 1):
        try:
            await self.connect()
            logger.info(f"Successfully reconnected WebSocket {self.context_id} on attempt {attempt_no}")
            return True
        except Exception as connect_err:
            if attempt_no < max_retries:
                logger.warning(f"Reconnection attempt {attempt_no} failed: {connect_err}. Retrying in {retry_delay}s...")
                await asyncio.sleep(retry_delay)
            else:
                logger.error(f"Failed to reconnect WebSocket {self.context_id} after {max_retries} attempts: {connect_err}")
                return False

    return False
121+
122+
def is_connected(self) -> bool:
    """Report whether the socket is open and its receive task is still running.

    Returns:
        True only when ``self._ws`` is set and not closed AND
        ``self._receive_task`` is set and not done; False otherwise.
    """
    ws = self._ws
    if ws is None or ws.closed:
        return False

    # The receive task is only inspected once the socket itself looks usable.
    task = self._receive_task
    if task is None:
        return False
    return not task.done()
130+
131+
async def ensure_connected(self):
    """Pre-flight check: return if connected, otherwise reconnect or raise.

    Raises:
        Exception: if ``reconnect()`` reports that it could not re-establish
            the WebSocket.
    """
    # Nothing to do when the connection already looks usable.
    if self.is_connected():
        return

    logger.warning(f"WebSocket {self.context_id} is not connected, attempting to reconnect")
    reconnected = await self.reconnect()
    if reconnected:
        return
    raise Exception(f"Failed to reconnect WebSocket {self.context_id}")
138+
83139
def _get_execute_request(
84140
self, msg_id: str, code: Union[str, StrictStr], background: bool
85141
) -> str:
@@ -209,11 +265,15 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]):
209265
cleanup_code = self._reset_env_vars_code(env_vars)
210266
if cleanup_code:
211267
logger.info(f"Cleaning up env vars: {cleanup_code}")
268+
# Ensure WebSocket is connected before sending cleanup request
269+
await self.ensure_connected()
212270
request = self._get_execute_request(message_id, cleanup_code, True)
271+
if self._ws is None:
272+
raise Exception("WebSocket not connected")
213273
await self._ws.send(request)
214274

215275
async for item in self._wait_for_result(message_id):
216-
if item["type"] == "error":
276+
if isinstance(item, dict) and item.get("type") == "error":
217277
logger.error(f"Error during env var cleanup: {item}")
218278
finally:
219279
del self._executions[message_id]
@@ -242,6 +302,10 @@ async def change_current_directory(
242302
):
243303
message_id = str(uuid.uuid4())
244304
self._executions[message_id] = Execution(in_background=True)
305+
306+
# Ensure WebSocket is connected before changing directory
307+
await self.ensure_connected()
308+
245309
if language == "python":
246310
request = self._get_execute_request(message_id, f"%cd {path}", True)
247311
elif language == "deno":
@@ -262,10 +326,13 @@ async def change_current_directory(
262326
else:
263327
return
264328

329+
if self._ws is None:
330+
raise Exception("WebSocket not connected")
331+
265332
await self._ws.send(request)
266333

267334
async for item in self._wait_for_result(message_id):
268-
if item["type"] == "error":
335+
if isinstance(item, dict) and item.get("type") == "error":
269336
raise ExecutionError(f"Error during execution: {item}")
270337

271338
async def execute(
@@ -277,8 +344,8 @@ async def execute(
277344
message_id = str(uuid.uuid4())
278345
self._executions[message_id] = Execution()
279346

280-
if self._ws is None:
281-
raise Exception("WebSocket not connected")
347+
# Ensure WebSocket is connected before executing
348+
await self.ensure_connected()
282349

283350
async with self._lock:
284351
# Wait for any pending cleanup task to complete
@@ -319,6 +386,8 @@ async def execute(
319386
request = self._get_execute_request(message_id, complete_code, False)
320387

321388
# Send the code for execution
389+
if self._ws is None:
390+
raise Exception("WebSocket not connected")
322391
await self._ws.send(request)
323392

324393
# Stream the results
@@ -343,6 +412,16 @@ async def _receive_message(self):
343412
await self._process_message(json.loads(message))
344413
except Exception as e:
345414
logger.error(f"WebSocket received error while receiving messages: {str(e)}")
415+
# Mark all pending executions as failed due to connection loss
416+
for execution in self._executions.values():
417+
await execution.queue.put(
418+
Error(
419+
name="ConnectionLost",
420+
value="WebSocket connection was lost during execution",
421+
traceback="",
422+
)
423+
)
424+
await execution.queue.put(UnexpectedEndOfExecution())
346425

347426
async def _process_message(self, data: dict):
348427
"""

0 commit comments

Comments
 (0)