Skip to content

Commit be28fb7

Browse files
Handle websocket error gracefully + add retry logic (#149)
* Handle websocket error gracefully * Lint * Remove early return * Remove traceback * Add reconnect logic * Fix headers for unsecure sandbox * Update the error message * Lint * Clean up * Don't reconnect in the last iteration * Change the error log * Don't reconnect in the receive task * Wait till the receive task doesn't finish * Revert the interval * Simplify * Address PR comments * Increase the retries * Update template/server/messaging.py Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> * Disable timeout completely * Fix missing access token * Simplify * Add changeset * Revert unwanted change --------- Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
1 parent 4f4ddc3 commit be28fb7

File tree

5 files changed

+91
-12
lines changed

5 files changed

+91
-12
lines changed

.changeset/curly-pumpkins-kick.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@e2b/code-interpreter-template': patch
3+
---
4+
5+
Add retry

.changeset/wicked-mirrors-punch.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@e2b/code-interpreter-python': patch
3+
---
4+
5+
Fix issue with secure False

python/e2b_code_interpreter/code_interpreter_async.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ async def run_code(
191191
request_timeout = request_timeout or self.connection_config.request_timeout
192192
context_id = context.id if context else None
193193

194+
headers: Dict[str, str] = {}
195+
if self._envd_access_token:
196+
headers = {"X-Access-Token": self._envd_access_token}
197+
194198
try:
195199
async with self._client.stream(
196200
"POST",
@@ -201,7 +205,7 @@ async def run_code(
201205
"language": language,
202206
"env_vars": envs,
203207
},
204-
headers={"X-Access-Token": self._envd_access_token},
208+
headers=headers,
205209
timeout=(request_timeout, timeout, request_timeout, request_timeout),
206210
) as response:
207211
err = await aextract_exception(response)
@@ -249,10 +253,14 @@ async def create_code_context(
249253
if cwd:
250254
data["cwd"] = cwd
251255

256+
headers: Dict[str, str] = {}
257+
if self._envd_access_token:
258+
headers = {"X-Access-Token": self._envd_access_token}
259+
252260
try:
253261
response = await self._client.post(
254262
f"{self._jupyter_url}/contexts",
255-
headers={"X-Access-Token": self._envd_access_token},
263+
headers=headers,
256264
json=data,
257265
timeout=request_timeout or self.connection_config.request_timeout,
258266
)

python/e2b_code_interpreter/code_interpreter_sync.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ def run_code(
188188
request_timeout = request_timeout or self.connection_config.request_timeout
189189
context_id = context.id if context else None
190190

191+
headers: Dict[str, str] = {}
192+
if self._envd_access_token:
193+
headers = {"X-Access-Token": self._envd_access_token}
194+
191195
try:
192196
with self._client.stream(
193197
"POST",
@@ -198,7 +202,7 @@ def run_code(
198202
"language": language,
199203
"env_vars": envs,
200204
},
201-
headers={"X-Access-Token": self._envd_access_token},
205+
headers=headers,
202206
timeout=(request_timeout, timeout, request_timeout, request_timeout),
203207
) as response:
204208
err = extract_exception(response)
@@ -246,11 +250,15 @@ def create_code_context(
246250
if cwd:
247251
data["cwd"] = cwd
248252

253+
headers: Dict[str, str] = {}
254+
if self._envd_access_token:
255+
headers = {"X-Access-Token": self._envd_access_token}
256+
249257
try:
250258
response = self._client.post(
251259
f"{self._jupyter_url}/contexts",
252260
json=data,
253-
headers={"X-Access-Token": self._envd_access_token},
261+
headers=headers,
254262
timeout=request_timeout or self.connection_config.request_timeout,
255263
)
256264

template/server/messaging.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
)
1313
from pydantic import StrictStr
1414
from websockets.client import WebSocketClientProtocol, connect
15+
from websockets.exceptions import (
16+
ConnectionClosedError,
17+
WebSocketException,
18+
)
1519

1620
from api.models.error import Error
1721
from api.models.logs import Stdout, Stderr
@@ -27,6 +31,9 @@
2731

2832
logger = logging.getLogger(__name__)
2933

34+
MAX_RECONNECT_RETRIES = 3
35+
PING_TIMEOUT = 30
36+
3037

3138
class Execution:
3239
def __init__(self, in_background: bool = False):
@@ -61,6 +68,15 @@ def __init__(self, context_id: str, session_id: str, language: str, cwd: str):
6168
self._executions: Dict[str, Execution] = {}
6269
self._lock = asyncio.Lock()
6370

71+
async def reconnect(self):
72+
if self._ws is not None:
73+
await self._ws.close(reason="Reconnecting")
74+
75+
if self._receive_task is not None:
76+
await self._receive_task
77+
78+
await self.connect()
79+
6480
async def connect(self):
6581
logger.debug(f"WebSocket connecting to {self.url}")
6682

@@ -69,6 +85,7 @@ async def connect(self):
6985

7086
self._ws = await connect(
7187
self.url,
88+
ping_timeout=PING_TIMEOUT,
7289
max_size=None,
7390
max_queue=None,
7491
logger=ws_logger,
@@ -274,9 +291,6 @@ async def execute(
274291
env_vars: Dict[StrictStr, str],
275292
access_token: str,
276293
):
277-
message_id = str(uuid.uuid4())
278-
self._executions[message_id] = Execution()
279-
280294
if self._ws is None:
281295
raise Exception("WebSocket not connected")
282296

@@ -313,13 +327,40 @@ async def execute(
313327
)
314328
complete_code = f"{indented_env_code}\n{complete_code}"
315329

316-
logger.info(
317-
f"Sending code for the execution ({message_id}): {complete_code}"
318-
)
319-
request = self._get_execute_request(message_id, complete_code, False)
330+
message_id = str(uuid.uuid4())
331+
execution = Execution()
332+
self._executions[message_id] = execution
320333

321334
# Send the code for execution
322-
await self._ws.send(request)
335+
# Initial request and retries
336+
for i in range(1 + MAX_RECONNECT_RETRIES):
337+
try:
338+
logger.info(
339+
f"Sending code for the execution ({message_id}): {complete_code}"
340+
)
341+
request = self._get_execute_request(
342+
message_id, complete_code, False
343+
)
344+
await self._ws.send(request)
345+
break
346+
except (ConnectionClosedError, WebSocketException) as e:
347+
# Keep the last result, even if error
348+
if i < MAX_RECONNECT_RETRIES:
349+
logger.warning(
350+
f"WebSocket connection lost while sending execution request, {i + 1}. reconnecting...: {str(e)}"
351+
)
352+
await self.reconnect()
353+
else:
354+
# The retry didn't help, request wasn't sent successfully
355+
logger.error("Failed to send execution request")
356+
await execution.queue.put(
357+
Error(
358+
name="WebSocketError",
359+
value="Failed to send execution request",
360+
traceback="",
361+
)
362+
)
363+
await execution.queue.put(UnexpectedEndOfExecution())
323364

324365
# Stream the results
325366
async for item in self._wait_for_result(message_id):
@@ -343,6 +384,18 @@ async def _receive_message(self):
343384
await self._process_message(json.loads(message))
344385
except Exception as e:
345386
logger.error(f"WebSocket received error while receiving messages: {str(e)}")
387+
finally:
388+
# To prevent infinite hang, we need to cancel all ongoing execution as we could lost results during the reconnect
389+
# Thanks to the locking, there can be either no ongoing execution or just one.
390+
for key, execution in self._executions.items():
391+
await execution.queue.put(
392+
Error(
393+
name="WebSocketError",
394+
value="The connections was lost, rerun the code to get the results",
395+
traceback="",
396+
)
397+
)
398+
await execution.queue.put(UnexpectedEndOfExecution())
346399

347400
async def _process_message(self, data: dict):
348401
"""

0 commit comments

Comments
 (0)