Skip to content

Commit fbb037e

Browse files
committed
handle setting global env vars (remove in background)
1 parent 0f3766a commit fbb037e

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

template/server/main.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,6 @@ async def post_execute(request: ExecutionRequest):
105105
status_code=404,
106106
)
107107

108-
# set global env vars if not set on first execution
109-
if not ws.global_env_vars:
110-
ws.global_env_vars = await get_envs()
111-
await ws.set_env_vars(ws.global_env_vars)
112-
113108
return StreamingListJsonResponse(
114109
ws.execute(
115110
request.code,

template/server/messaging.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
UnexpectedEndOfExecution,
2424
)
2525
from errors import ExecutionError
26+
from envs import get_envs
2627

2728
logger = logging.getLogger(__name__)
2829

@@ -48,6 +49,7 @@ class ContextWebSocket:
4849
_ws: Optional[WebSocketClientProtocol] = None
4950
_receive_task: Optional[asyncio.Task] = None
5051
global_env_vars: Optional[Dict[StrictStr, str]] = None
52+
global_env_vars_set = False
5153

5254
def __init__(
5355
self,
@@ -175,6 +177,23 @@ def _reset_env_vars_code(self, env_vars: Dict[StrictStr, str]) -> str:
175177

176178
return "\n".join(cleanup_commands)
177179

180+
async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]):
181+
"""Clean up environment variables in a separate execution request."""
182+
message_id = str(uuid.uuid4())
183+
self._executions[message_id] = Execution(in_background=True)
184+
185+
cleanup_code = self._reset_env_vars_code(env_vars)
186+
if cleanup_code:
187+
logger.info(f"Cleaning up env vars: {cleanup_code}")
188+
request = self._get_execute_request(message_id, cleanup_code, True)
189+
await self._ws.send(request)
190+
191+
async for item in self._wait_for_result(message_id):
192+
if item["type"] == "error":
193+
logger.error(f"Error during env var cleanup: {item}")
194+
195+
del self._executions[message_id]
196+
178197
async def _wait_for_result(self, message_id: str):
179198
queue = self._executions[message_id].queue
180199

@@ -241,17 +260,19 @@ async def execute(
241260
async with self._lock:
242261
# Build the complete code snippet with env vars
243262
complete_code = code
263+
264+
if not self.global_env_vars:
265+
self.global_env_vars = await get_envs()
266+
267+
if not self.global_env_vars_set and self.global_env_vars:
268+
complete_code = self._set_env_vars_code(self.global_env_vars)
269+
self.global_env_vars_set = True
244270

245271
if env_vars:
246272
# Add env var setup at the beginning
247273
env_setup_code = self._set_env_vars_code(env_vars)
248274
if env_setup_code:
249275
complete_code = f"{env_setup_code}\n{complete_code}"
250-
251-
# Add env var cleanup at the end
252-
env_cleanup_code = self._reset_env_vars_code(env_vars)
253-
if env_cleanup_code:
254-
complete_code = f"{complete_code}\n{env_cleanup_code}"
255276

256277
logger.info(f"Executing complete code: {complete_code}")
257278
request = self._get_execute_request(message_id, complete_code, False)
@@ -265,6 +286,10 @@ async def execute(
265286

266287
del self._executions[message_id]
267288

289+
# Clean up env vars in a separate request after the main code has run (outside the lock)
290+
if env_vars:
291+
asyncio.create_task(self._cleanup_env_vars(env_vars))
292+
268293
async def _receive_message(self):
269294
if not self._ws:
270295
logger.error("No WebSocket connection")

0 commit comments

Comments
 (0)