Skip to content

Commit f34cf8f

Browse files
committed
reset env variables to global in case of override or delete them after execution
1 parent 3d9cb5b commit f34cf8f

File tree

2 files changed

+79
-23
lines changed

2 files changed

+79
-23
lines changed

template/server/contexts.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from consts import JUPYTER_BASE_URL
99
from errors import ExecutionError
1010
from messaging import ContextWebSocket
11+
from envs import get_envs
1112

1213
logger = logging.Logger(__name__)
1314

@@ -51,6 +52,7 @@ async def create_context(client, websockets: dict, language: str, cwd: str) -> C
5152
session_data = response.json()
5253
session_id = session_data["id"]
5354
context_id = session_data["kernel"]["id"]
55+
global_env_vars = get_envs()
5456

5557
logger.debug(f"Created context {context_id}")
5658

@@ -67,4 +69,12 @@ async def create_context(client, websockets: dict, language: str, cwd: str) -> C
6769
status_code=500,
6870
)
6971

72+
try:
73+
await ws.set_env_vars(global_env_vars)
74+
except ExecutionError as e:
75+
return PlainTextResponse(
76+
"Failed to set environment variables",
77+
status_code=500,
78+
)
79+
7080
return Context(language=language, id=context_id, cwd=cwd)

template/server/messaging.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,69 @@ async def _wait_for_result(self, message_id: str):
135135

136136
yield output.model_dump(exclude_none=True)
137137

138+
async def set_env_vars(self, env_vars: Dict[StrictStr, str]):
139+
env_commands = []
140+
141+
for k, v in env_vars.items():
142+
if self.language == "python":
143+
env_commands.append(f"os.environ['{k}'] = '{v}'")
144+
elif self.language in ["javascript", "typescript"]:
145+
env_commands.append(f"process.env['{k}'] = '{v}'")
146+
elif self.language == "deno":
147+
env_commands.append(f"Deno.env.set('{k}', '{v}')")
148+
elif self.language == "r":
149+
env_commands.append(f"Sys.setenv('{k}' = '{v}')")
150+
elif self.language == "java":
151+
env_commands.append(f"System.setProperty('{k}', '{v}')")
152+
elif self.language == "bash":
153+
env_commands.append(f"export {k}='{v}'")
154+
155+
if env_commands:
156+
env_vars_snippet = "\n".join(env_commands)
157+
print(f"Setting env vars: {env_vars_snippet}")
158+
request = self._get_execute_request(str(uuid.uuid4()), env_vars_snippet, False)
159+
await self._ws.send(request)
160+
161+
async def reset_env_vars(self, env_vars: Dict[StrictStr, str]):
162+
global_env_vars = get_envs()
163+
164+
# Create a dict of vars to reset and a list of vars to remove
165+
vars_to_reset = {}
166+
vars_to_remove = []
167+
168+
for key in env_vars:
169+
if key in global_env_vars:
170+
vars_to_reset[key] = global_env_vars[key]
171+
else:
172+
vars_to_remove.append(key)
173+
174+
# Reset vars that exist in global env vars
175+
if vars_to_reset:
176+
await self.set_env_vars(vars_to_reset)
177+
178+
# Remove vars that don't exist in global env vars
179+
if vars_to_remove:
180+
remove_commands = []
181+
for key in vars_to_remove:
182+
if self.language == "python":
183+
remove_commands.append(f"del os.environ['{key}']")
184+
elif self.language in ["javascript", "typescript"]:
185+
remove_commands.append(f"delete process.env['{key}']")
186+
elif self.language == "deno":
187+
remove_commands.append(f"Deno.env.delete('{key}')")
188+
elif self.language == "r":
189+
remove_commands.append(f"Sys.unsetenv('{key}')")
190+
elif self.language == "java":
191+
remove_commands.append(f"System.clearProperty('{key}')")
192+
elif self.language == "bash":
193+
remove_commands.append(f"unset {key}")
194+
195+
if remove_commands:
196+
remove_snippet = "\n".join(remove_commands)
197+
print(f"Removing env vars: {remove_snippet}")
198+
request = self._get_execute_request(str(uuid.uuid4()), remove_snippet, False)
199+
await self._ws.send(request)
200+
138201
async def change_current_directory(
139202
self, path: Union[str, StrictStr], language: str
140203
):
@@ -178,31 +241,10 @@ async def execute(
178241
if self._ws is None:
179242
raise Exception("WebSocket not connected")
180243

181-
global_env_vars = get_envs()
182-
env_vars = {**global_env_vars, **env_vars} if env_vars else global_env_vars
183244
async with self._lock:
245+
# set env vars (will override global env vars)
184246
if env_vars:
185-
vars_to_set = {**global_env_vars, **env_vars}
186-
env_vars_snippet = ""
187-
188-
if self.language == "python":
189-
env_vars_snippet = f"os.environ.set_envs_for_execution({vars_to_set})\n"
190-
elif self.language in ["javascript", "typescript"]:
191-
env_vars_snippet = "\n".join([f"process.env['{k}'] = '{v}';" for k, v in vars_to_set.items()])
192-
elif self.language == "deno":
193-
env_vars_snippet = "\n".join([f"Deno.env.set('{k}', '{v}');" for k, v in vars_to_set.items()])
194-
elif self.language == "r":
195-
env_vars_snippet = "\n".join([f"Sys.setenv('{k}' = '{v}')" for k, v in vars_to_set.items()])
196-
elif self.language == "java":
197-
env_vars_snippet = "\n".join([f"System.setProperty('{k}', '{v}');" for k, v in vars_to_set.items()])
198-
elif self.language == "bash":
199-
env_vars_snippet = "\n".join([f"export {k}='{v}'" for k, v in vars_to_set.items()])
200-
else:
201-
raise Exception(f"Unsupported language: {self.language}")
202-
203-
print(f"Setting env vars: {env_vars_snippet}")
204-
request = self._get_execute_request(str(uuid.uuid4()), env_vars_snippet, False)
205-
await self._ws.send(request)
247+
await self.set_env_vars(env_vars)
206248

207249
if self.language == "typescript":
208250
logger.info("Compiling TypeScript: %s", code)
@@ -242,6 +284,10 @@ async def execute(
242284

243285
del self._executions[message_id]
244286

287+
# reset env vars to their previous values, if they were set globally or remove them
288+
if env_vars:
289+
await self.reset_env_vars(env_vars)
290+
245291
async def _receive_message(self):
246292
if not self._ws:
247293
logger.error("No WebSocket connection")

0 commit comments

Comments
 (0)