Skip to content

Commit 37c2a22

Browse files
committed
fix a race condition when subsequent request fire too rapidly before the cleanup could complete
1 parent aef5740 commit 37c2a22

File tree

1 file changed

+37
-15
lines changed

1 file changed

+37
-15
lines changed

template/server/messaging.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import uuid
55
import asyncio
6-
import textwrap
76

87
from asyncio import Queue
98
from typing import (
@@ -51,6 +50,7 @@ class ContextWebSocket:
5150
_receive_task: Optional[asyncio.Task] = None
5251
global_env_vars: Optional[Dict[StrictStr, str]] = None
5352
global_env_vars_set = False
53+
_cleanup_task: Optional[asyncio.Task] = None
5454

5555
def __init__(
5656
self,
@@ -209,17 +209,21 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]):
209209
message_id = str(uuid.uuid4())
210210
self._executions[message_id] = Execution(in_background=True)
211211

212-
cleanup_code = self._reset_env_vars_code(env_vars)
213-
if cleanup_code:
214-
logger.info(f"Cleaning up env vars: {cleanup_code}")
215-
request = self._get_execute_request(message_id, cleanup_code, True)
216-
await self._ws.send(request)
217-
218-
async for item in self._wait_for_result(message_id):
219-
if item["type"] == "error":
220-
logger.error(f"Error during env var cleanup: {item}")
221-
222-
del self._executions[message_id]
212+
try:
213+
cleanup_code = self._reset_env_vars_code(env_vars)
214+
if cleanup_code:
215+
logger.info(f"Cleaning up env vars: {cleanup_code}")
216+
request = self._get_execute_request(message_id, cleanup_code, True)
217+
await self._ws.send(request)
218+
219+
async for item in self._wait_for_result(message_id):
220+
if item["type"] == "error":
221+
logger.error(f"Error during env var cleanup: {item}")
222+
finally:
223+
del self._executions[message_id]
224+
# Clear the task reference when cleanup is complete
225+
if self._cleanup_task and self._cleanup_task.done():
226+
self._cleanup_task = None
223227

224228
async def _wait_for_result(self, message_id: str):
225229
queue = self._executions[message_id].queue
@@ -285,6 +289,16 @@ async def execute(
285289
raise Exception("WebSocket not connected")
286290

287291
async with self._lock:
292+
# Wait for any pending cleanup task to complete
293+
if self._cleanup_task and not self._cleanup_task.done():
294+
logger.debug("Waiting for pending cleanup task to complete")
295+
try:
296+
await self._cleanup_task
297+
except Exception as e:
298+
logger.warning(f"Cleanup task failed: {e}")
299+
finally:
300+
self._cleanup_task = None
301+
288302
# Get the indentation level from the code
289303
code_indent = self._get_code_indentation(code)
290304

@@ -320,9 +334,9 @@ async def execute(
320334

321335
del self._executions[message_id]
322336

323-
# Clean up env vars in a separate request after the main code has run (outside the lock)
324-
if env_vars:
325-
asyncio.create_task(self._cleanup_env_vars(env_vars))
337+
# Clean up env vars in a separate request after the main code has run
338+
if env_vars:
339+
self._cleanup_task = asyncio.create_task(self._cleanup_env_vars(env_vars))
326340

327341
async def _receive_message(self):
328342
if not self._ws:
@@ -485,5 +499,13 @@ async def close(self):
485499
if self._receive_task is not None:
486500
self._receive_task.cancel()
487501

502+
# Cancel any pending cleanup task
503+
if self._cleanup_task and not self._cleanup_task.done():
504+
self._cleanup_task.cancel()
505+
try:
506+
await self._cleanup_task
507+
except asyncio.CancelledError:
508+
pass
509+
488510
for execution in self._executions.values():
489511
execution.queue.put_nowait(UnexpectedEndOfExecution())

0 commit comments

Comments
 (0)