|
3 | 3 | import logging |
4 | 4 | import uuid |
5 | 5 | import asyncio |
6 | | -import textwrap |
7 | 6 |
|
8 | 7 | from asyncio import Queue |
9 | 8 | from typing import ( |
@@ -51,6 +50,7 @@ class ContextWebSocket: |
51 | 50 | _receive_task: Optional[asyncio.Task] = None |
52 | 51 | global_env_vars: Optional[Dict[StrictStr, str]] = None |
53 | 52 | global_env_vars_set = False |
| 53 | + _cleanup_task: Optional[asyncio.Task] = None |
54 | 54 |
|
55 | 55 | def __init__( |
56 | 56 | self, |
@@ -209,17 +209,21 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]): |
209 | 209 | message_id = str(uuid.uuid4()) |
210 | 210 | self._executions[message_id] = Execution(in_background=True) |
211 | 211 |
|
212 | | - cleanup_code = self._reset_env_vars_code(env_vars) |
213 | | - if cleanup_code: |
214 | | - logger.info(f"Cleaning up env vars: {cleanup_code}") |
215 | | - request = self._get_execute_request(message_id, cleanup_code, True) |
216 | | - await self._ws.send(request) |
217 | | - |
218 | | - async for item in self._wait_for_result(message_id): |
219 | | - if item["type"] == "error": |
220 | | - logger.error(f"Error during env var cleanup: {item}") |
221 | | - |
222 | | - del self._executions[message_id] |
| 212 | + try: |
| 213 | + cleanup_code = self._reset_env_vars_code(env_vars) |
| 214 | + if cleanup_code: |
| 215 | + logger.info(f"Cleaning up env vars: {cleanup_code}") |
| 216 | + request = self._get_execute_request(message_id, cleanup_code, True) |
| 217 | + await self._ws.send(request) |
| 218 | + |
| 219 | + async for item in self._wait_for_result(message_id): |
| 220 | + if item["type"] == "error": |
| 221 | + logger.error(f"Error during env var cleanup: {item}") |
| 222 | + finally: |
| 223 | + del self._executions[message_id] |
| 224 | + # Clear the task reference when cleanup is complete |
| 225 | + if self._cleanup_task and self._cleanup_task.done(): |
| 226 | + self._cleanup_task = None |
223 | 227 |
|
224 | 228 | async def _wait_for_result(self, message_id: str): |
225 | 229 | queue = self._executions[message_id].queue |
@@ -285,6 +289,16 @@ async def execute( |
285 | 289 | raise Exception("WebSocket not connected") |
286 | 290 |
|
287 | 291 | async with self._lock: |
| 292 | + # Wait for any pending cleanup task to complete |
| 293 | + if self._cleanup_task and not self._cleanup_task.done(): |
| 294 | + logger.debug("Waiting for pending cleanup task to complete") |
| 295 | + try: |
| 296 | + await self._cleanup_task |
| 297 | + except Exception as e: |
| 298 | + logger.warning(f"Cleanup task failed: {e}") |
| 299 | + finally: |
| 300 | + self._cleanup_task = None |
| 301 | + |
288 | 302 | # Get the indentation level from the code |
289 | 303 | code_indent = self._get_code_indentation(code) |
290 | 304 |
|
@@ -320,9 +334,9 @@ async def execute( |
320 | 334 |
|
321 | 335 | del self._executions[message_id] |
322 | 336 |
|
323 | | - # Clean up env vars in a separate request after the main code has run (outside the lock) |
324 | | - if env_vars: |
325 | | - asyncio.create_task(self._cleanup_env_vars(env_vars)) |
| 337 | + # Clean up env vars in a separate request after the main code has run |
| 338 | + if env_vars: |
| 339 | + self._cleanup_task = asyncio.create_task(self._cleanup_env_vars(env_vars)) |
326 | 340 |
|
327 | 341 | async def _receive_message(self): |
328 | 342 | if not self._ws: |
@@ -485,5 +499,13 @@ async def close(self): |
485 | 499 | if self._receive_task is not None: |
486 | 500 | self._receive_task.cancel() |
487 | 501 |
|
| 502 | + # Cancel any pending cleanup task |
| 503 | + if self._cleanup_task and not self._cleanup_task.done(): |
| 504 | + self._cleanup_task.cancel() |
| 505 | + try: |
| 506 | + await self._cleanup_task |
| 507 | + except asyncio.CancelledError: |
| 508 | + pass |
| 509 | + |
488 | 510 | for execution in self._executions.values(): |
489 | 511 | execution.queue.put_nowait(UnexpectedEndOfExecution()) |
0 commit comments