Skip to content

Commit 2072a8e

Browse files
committed
set global envs on first execution
1 parent 645f91d commit 2072a8e

File tree

3 files changed

+10
-18
lines changed

3 files changed

+10
-18
lines changed

template/server/contexts.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from consts import JUPYTER_BASE_URL
99
from errors import ExecutionError
1010
from messaging import ContextWebSocket
11-
from envs import get_envs
1211

1312
logger = logging.Logger(__name__)
1413

@@ -52,7 +51,6 @@ async def create_context(client, websockets: dict, language: str, cwd: str) -> C
5251
session_data = response.json()
5352
session_id = session_data["id"]
5453
context_id = session_data["kernel"]["id"]
55-
global_env_vars = await get_envs()
5654

5755
logger.debug(f"Created context {context_id}")
5856

@@ -69,12 +67,4 @@ async def create_context(client, websockets: dict, language: str, cwd: str) -> C
6967
status_code=500,
7068
)
7169

72-
try:
73-
await ws.set_env_vars(global_env_vars)
74-
except ExecutionError as e:
75-
return PlainTextResponse(
76-
"Failed to set environment variables",
77-
status_code=500,
78-
)
79-
8070
return Context(language=language, id=context_id, cwd=cwd)

template/server/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from messaging import ContextWebSocket
1818
from stream import StreamingListJsonResponse
1919
from utils.locks import LockedMap
20+
from envs import get_envs
2021

2122
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
2223
logger = logging.Logger(__name__)
@@ -104,6 +105,11 @@ async def post_execute(request: ExecutionRequest):
104105
status_code=404,
105106
)
106107

108+
# set global env vars if not set on first execution
109+
if not ws.global_env_vars:
110+
ws.global_env_vars = await get_envs()
111+
await ws.set_env_vars(ws.global_env_vars)
112+
107113
return StreamingListJsonResponse(
108114
ws.execute(
109115
request.code,

template/server/messaging.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
import logging
44
import uuid
55
import asyncio
6-
import subprocess
76

87
from asyncio import Queue
9-
from envs import get_envs
108
from typing import (
119
Dict,
1210
Optional,
@@ -49,20 +47,20 @@ def __init__(self, in_background: bool = False):
4947
class ContextWebSocket:
5048
_ws: Optional[WebSocketClientProtocol] = None
5149
_receive_task: Optional[asyncio.Task] = None
50+
global_env_vars: Optional[Dict[StrictStr, str]] = None
5251

5352
def __init__(
5453
self,
5554
context_id: str,
5655
session_id: str,
5756
language: str,
58-
cwd: str,
57+
cwd: str
5958
):
6059
self.language = language
6160
self.cwd = cwd
6261
self.context_id = context_id
6362
self.url = f"ws://localhost:8888/api/kernels/{context_id}/channels"
6463
self.session_id = session_id
65-
6664
self._executions: Dict[str, Execution] = {}
6765
self._lock = asyncio.Lock()
6866

@@ -167,15 +165,13 @@ async def set_env_vars(self, env_vars: Dict[StrictStr, str]):
167165
raise ExecutionError(f"Error during execution: {item}")
168166

169167
async def reset_env_vars(self, env_vars: Dict[StrictStr, str]):
170-
global_env_vars = await get_envs()
171-
172168
# Create a dict of vars to reset and a list of vars to remove
173169
vars_to_reset = {}
174170
vars_to_remove = []
175171

176172
for key in env_vars:
177-
if key in global_env_vars:
178-
vars_to_reset[key] = global_env_vars[key]
173+
if self.global_env_vars and key in self.global_env_vars:
174+
vars_to_reset[key] = self.global_env_vars[key]
179175
else:
180176
vars_to_remove.append(key)
181177

0 commit comments

Comments
 (0)