|  | 
| 3 | 3 | import logging | 
| 4 | 4 | import uuid | 
| 5 | 5 | import asyncio | 
| 6 |  | -import textwrap | 
| 7 | 6 | 
 | 
| 8 | 7 | from asyncio import Queue | 
| 9 | 8 | from typing import ( | 
| @@ -51,6 +50,7 @@ class ContextWebSocket: | 
| 51 | 50 |     _receive_task: Optional[asyncio.Task] = None | 
| 52 | 51 |     global_env_vars: Optional[Dict[StrictStr, str]] = None | 
| 53 | 52 |     global_env_vars_set = False | 
|  | 53 | +    _cleanup_task: Optional[asyncio.Task] = None | 
| 54 | 54 | 
 | 
| 55 | 55 |     def __init__( | 
| 56 | 56 |         self, | 
| @@ -209,17 +209,21 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]): | 
| 209 | 209 |         message_id = str(uuid.uuid4()) | 
| 210 | 210 |         self._executions[message_id] = Execution(in_background=True) | 
| 211 | 211 | 
 | 
| 212 |  | -        cleanup_code = self._reset_env_vars_code(env_vars) | 
| 213 |  | -        if cleanup_code: | 
| 214 |  | -            logger.info(f"Cleaning up env vars: {cleanup_code}") | 
| 215 |  | -            request = self._get_execute_request(message_id, cleanup_code, True) | 
| 216 |  | -            await self._ws.send(request) | 
| 217 |  | - | 
| 218 |  | -            async for item in self._wait_for_result(message_id): | 
| 219 |  | -                if item["type"] == "error": | 
| 220 |  | -                    logger.error(f"Error during env var cleanup: {item}") | 
| 221 |  | - | 
| 222 |  | -        del self._executions[message_id] | 
|  | 212 | +        try: | 
|  | 213 | +            cleanup_code = self._reset_env_vars_code(env_vars) | 
|  | 214 | +            if cleanup_code: | 
|  | 215 | +                logger.info(f"Cleaning up env vars: {cleanup_code}") | 
|  | 216 | +                request = self._get_execute_request(message_id, cleanup_code, True) | 
|  | 217 | +                await self._ws.send(request) | 
|  | 218 | + | 
|  | 219 | +                async for item in self._wait_for_result(message_id): | 
|  | 220 | +                    if item["type"] == "error": | 
|  | 221 | +                        logger.error(f"Error during env var cleanup: {item}") | 
|  | 222 | +        finally: | 
|  | 223 | +            del self._executions[message_id] | 
|  | 224 | +            # Clear the task reference when cleanup is complete | 
|  | 225 | +            if self._cleanup_task and self._cleanup_task.done(): | 
|  | 226 | +                self._cleanup_task = None | 
| 223 | 227 | 
 | 
| 224 | 228 |     async def _wait_for_result(self, message_id: str): | 
| 225 | 229 |         queue = self._executions[message_id].queue | 
| @@ -285,6 +289,16 @@ async def execute( | 
| 285 | 289 |             raise Exception("WebSocket not connected") | 
| 286 | 290 | 
 | 
| 287 | 291 |         async with self._lock: | 
|  | 292 | +            # Wait for any pending cleanup task to complete | 
|  | 293 | +            if self._cleanup_task and not self._cleanup_task.done(): | 
|  | 294 | +                logger.debug("Waiting for pending cleanup task to complete") | 
|  | 295 | +                try: | 
|  | 296 | +                    await self._cleanup_task | 
|  | 297 | +                except Exception as e: | 
|  | 298 | +                    logger.warning(f"Cleanup task failed: {e}") | 
|  | 299 | +                finally: | 
|  | 300 | +                    self._cleanup_task = None | 
|  | 301 | +             | 
| 288 | 302 |             # Get the indentation level from the code | 
| 289 | 303 |             code_indent = self._get_code_indentation(code) | 
| 290 | 304 | 
 | 
| @@ -320,9 +334,9 @@ async def execute( | 
| 320 | 334 | 
 | 
| 321 | 335 |             del self._executions[message_id] | 
| 322 | 336 | 
 | 
| 323 |  | -        # Clean up env vars in a separate request after the main code has run (outside the lock) | 
| 324 |  | -        if env_vars: | 
| 325 |  | -            asyncio.create_task(self._cleanup_env_vars(env_vars)) | 
|  | 337 | +            # Clean up env vars in a separate request after the main code has run | 
|  | 338 | +            if env_vars: | 
|  | 339 | +                self._cleanup_task = asyncio.create_task(self._cleanup_env_vars(env_vars)) | 
| 326 | 340 | 
 | 
| 327 | 341 |     async def _receive_message(self): | 
| 328 | 342 |         if not self._ws: | 
| @@ -485,5 +499,13 @@ async def close(self): | 
| 485 | 499 |         if self._receive_task is not None: | 
| 486 | 500 |             self._receive_task.cancel() | 
| 487 | 501 | 
 | 
|  | 502 | +        # Cancel any pending cleanup task | 
|  | 503 | +        if self._cleanup_task and not self._cleanup_task.done(): | 
|  | 504 | +            self._cleanup_task.cancel() | 
|  | 505 | +            try: | 
|  | 506 | +                await self._cleanup_task | 
|  | 507 | +            except asyncio.CancelledError: | 
|  | 508 | +                pass | 
|  | 509 | + | 
| 488 | 510 |         for execution in self._executions.values(): | 
| 489 | 511 |             execution.queue.put_nowait(UnexpectedEndOfExecution()) | 
0 commit comments