Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def template():

@pytest.fixture()
def sandbox(template, debug):
sandbox = Sandbox(template, timeout=timeout)
sandbox = Sandbox(template, timeout=timeout, debug=debug)

try:
yield sandbox
Expand All @@ -33,7 +33,7 @@ def sandbox(template, debug):

@pytest_asyncio.fixture
async def async_sandbox(template, debug):
async_sandbox = await AsyncSandbox.create(template, timeout=timeout)
async_sandbox = await AsyncSandbox.create(template, timeout=timeout, debug=debug)

try:
yield async_sandbox
Expand Down
5 changes: 0 additions & 5 deletions template/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,6 @@ async def post_execute(request: ExecutionRequest):
status_code=404,
)

# set global env vars if not set on first execution
if not ws.global_env_vars:
ws.global_env_vars = await get_envs()
await ws.set_env_vars(ws.global_env_vars)

return StreamingListJsonResponse(
ws.execute(
request.code,
Expand Down
240 changes: 152 additions & 88 deletions template/server/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
UnexpectedEndOfExecution,
)
from errors import ExecutionError
from envs import get_envs

logger = logging.getLogger(__name__)

Expand All @@ -47,7 +48,8 @@ def __init__(self, in_background: bool = False):
class ContextWebSocket:
_ws: Optional[WebSocketClientProtocol] = None
_receive_task: Optional[asyncio.Task] = None
global_env_vars: Optional[Dict[StrictStr, str]] = None
_global_env_vars: Optional[Dict[StrictStr, str]] = None
_cleanup_task: Optional[asyncio.Task] = None

def __init__(
self,
Expand Down Expand Up @@ -114,6 +116,113 @@ def _get_execute_request(
}
)

def _set_env_var_snippet(self, key: str, value: str) -> str:
"""Get environment variable set command for the current language."""
if self.language == "python":
return f"import os; os.environ['{key}'] = '{value}'"
elif self.language in ["javascript", "typescript"]:
return f"process.env['{key}'] = '{value}'"
elif self.language == "deno":
return f"Deno.env.set('{key}', '{value}')"
elif self.language == "r":
return f'Sys.setenv({key} = "{value}")'
elif self.language == "java":
return f'System.setProperty("{key}", "{value}");'
elif self.language == "bash":
return f"export {key}='{value}'"
return ""

def _delete_env_var_snippet(self, key: str) -> str:
"""Get environment variable delete command for the current language."""
if self.language == "python":
return f"import os; del os.environ['{key}']"
elif self.language in ["javascript", "typescript"]:
return f"delete process.env['{key}']"
elif self.language == "deno":
return f"Deno.env.delete('{key}')"
elif self.language == "r":
return f"Sys.unsetenv('{key}')"
elif self.language == "java":
return f'System.clearProperty("{key}");'
elif self.language == "bash":
return f"unset {key}"
return ""

def _set_env_vars_code(self, env_vars: Dict[StrictStr, str]) -> str:
"""Build environment variable code for the current language."""
env_commands = []
for k, v in env_vars.items():
command = self._set_env_var_snippet(k, v)
if command:
env_commands.append(command)

return "\n".join(env_commands)

def _reset_env_vars_code(self, env_vars: Dict[StrictStr, str]) -> str:
"""Build environment variable cleanup code for the current language."""
cleanup_commands = []

for key in env_vars:
# Check if this var exists in global env vars
if self._global_env_vars and key in self._global_env_vars:
# Reset to global value
value = self._global_env_vars[key]
command = self._set_env_var_snippet(key, value)
else:
# Remove the variable
command = self._delete_env_var_snippet(key)

if command:
cleanup_commands.append(command)

return "\n".join(cleanup_commands)

def _get_code_indentation(self, code: str) -> str:
"""Get the indentation from the first non-empty line of code."""
if not code or not code.strip():
return ""

lines = code.split('\n')
for line in lines:
if line.strip(): # First non-empty line
return line[:len(line) - len(line.lstrip())]

return ""

def _indent_code_with_level(self, code: str, indent_level: str) -> str:
"""Apply the given indentation level to each line of code."""
if not code or not indent_level:
return code

lines = code.split('\n')
indented_lines = []

for line in lines:
if line.strip(): # Non-empty lines
indented_lines.append(indent_level + line)
else:
indented_lines.append(line)

return '\n'.join(indented_lines)

async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]):
"""Clean up environment variables in a separate execution request."""
message_id = str(uuid.uuid4())
self._executions[message_id] = Execution(in_background=True)

try:
cleanup_code = self._reset_env_vars_code(env_vars)
if cleanup_code:
logger.info(f"Cleaning up env vars: {cleanup_code}")
request = self._get_execute_request(message_id, cleanup_code, True)
await self._ws.send(request)

async for item in self._wait_for_result(message_id):
if item["type"] == "error":
logger.error(f"Error during env var cleanup: {item}")
finally:
del self._executions[message_id]

async def _wait_for_result(self, message_id: str):
queue = self._executions[message_id].queue

Expand All @@ -133,84 +242,6 @@ async def _wait_for_result(self, message_id: str):

yield output.model_dump(exclude_none=True)

async def set_env_vars(self, env_vars: Dict[StrictStr, str]):
message_id = str(uuid.uuid4())
self._executions[message_id] = Execution(in_background=True)

env_commands = []
for k, v in env_vars.items():
if self.language == "python":
env_commands.append(f"import os; os.environ['{k}'] = '{v}'")
elif self.language in ["javascript", "typescript"]:
env_commands.append(f"process.env['{k}'] = '{v}'")
elif self.language == "deno":
env_commands.append(f"Deno.env.set('{k}', '{v}')")
elif self.language == "r":
env_commands.append(f'Sys.setenv({k} = "{v}")')
elif self.language == "java":
env_commands.append(f'System.setProperty("{k}", "{v}");')
elif self.language == "bash":
env_commands.append(f"export {k}='{v}'")
else:
return

if env_commands:
env_vars_snippet = "\n".join(env_commands)
logger.info(f"Setting env vars: {env_vars_snippet} for {self.language}")
request = self._get_execute_request(message_id, env_vars_snippet, True)
await self._ws.send(request)

async for item in self._wait_for_result(message_id):
if item["type"] == "error":
raise ExecutionError(f"Error during execution: {item}")

async def reset_env_vars(self, env_vars: Dict[StrictStr, str]):
# Create a dict of vars to reset and a list of vars to remove
vars_to_reset = {}
vars_to_remove = []

for key in env_vars:
if self.global_env_vars and key in self.global_env_vars:
vars_to_reset[key] = self.global_env_vars[key]
else:
vars_to_remove.append(key)

# Reset vars that exist in global env vars
if vars_to_reset:
await self.set_env_vars(vars_to_reset)

# Remove vars that don't exist in global env vars
if vars_to_remove:
message_id = str(uuid.uuid4())
self._executions[message_id] = Execution(in_background=True)

remove_commands = []
for key in vars_to_remove:
if self.language == "python":
remove_commands.append(f"import os; del os.environ['{key}']")
elif self.language in ["javascript", "typescript"]:
remove_commands.append(f"delete process.env['{key}']")
elif self.language == "deno":
remove_commands.append(f"Deno.env.delete('{key}')")
elif self.language == "r":
remove_commands.append(f"Sys.unsetenv('{key}')")
elif self.language == "java":
remove_commands.append(f'System.clearProperty("{key}");')
elif self.language == "bash":
remove_commands.append(f"unset {key}")
else:
return

if remove_commands:
remove_snippet = "\n".join(remove_commands)
logger.info(f"Removing env vars: {remove_snippet} for {self.language}")
request = self._get_execute_request(message_id, remove_snippet, True)
await self._ws.send(request)

async for item in self._wait_for_result(message_id):
if item["type"] == "error":
raise ExecutionError(f"Error during execution: {item}")

async def change_current_directory(
self, path: Union[str, StrictStr], language: str
):
Expand Down Expand Up @@ -248,20 +279,44 @@ async def execute(
env_vars: Dict[StrictStr, str] = None,
):
message_id = str(uuid.uuid4())
logger.debug(f"Sending code for the execution ({message_id}): {code}")

self._executions[message_id] = Execution()

if self._ws is None:
raise Exception("WebSocket not connected")

async with self._lock:
# set env vars (will override global env vars)
# Wait for any pending cleanup task to complete
if self._cleanup_task and not self._cleanup_task.done():
logger.debug("Waiting for pending cleanup task to complete")
try:
await self._cleanup_task
except Exception as e:
logger.warning(f"Cleanup task failed: {e}")
finally:
self._cleanup_task = None

# Get the indentation level from the code
code_indent = self._get_code_indentation(code)

# Build the complete code snippet with env vars
complete_code = code

global_env_vars_snippet = ""
env_vars_snippet = ""

if self._global_env_vars is None:
self._global_env_vars = await get_envs()
global_env_vars_snippet = self._set_env_vars_code(self._global_env_vars)

if env_vars:
await self.set_env_vars(env_vars)
env_vars_snippet = self._set_env_vars_code(env_vars)

logger.info(code)
request = self._get_execute_request(message_id, code, False)
if global_env_vars_snippet or env_vars_snippet:
indented_env_code = self._indent_code_with_level(f"{global_env_vars_snippet}\n{env_vars_snippet}", code_indent)
complete_code = f"{indented_env_code}\n{complete_code}"

logger.info(f"Sending code for the execution ({message_id}): {complete_code}")
request = self._get_execute_request(message_id, complete_code, False)

# Send the code for execution
await self._ws.send(request)
Expand All @@ -272,9 +327,9 @@ async def execute(

del self._executions[message_id]

# reset env vars to their previous values, if they were set globally or remove them
# Clean up env vars in a separate request after the main code has run
if env_vars:
await self.reset_env_vars(env_vars)
self._cleanup_task = asyncio.create_task(self._cleanup_env_vars(env_vars))

async def _receive_message(self):
if not self._ws:
Expand Down Expand Up @@ -434,7 +489,16 @@ async def close(self):
if self._ws is not None:
await self._ws.close()

self._receive_task.cancel()
if self._receive_task is not None:
self._receive_task.cancel()

# Cancel any pending cleanup task
if self._cleanup_task and not self._cleanup_task.done():
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass

for execution in self._executions.values():
execution.queue.put_nowait(UnexpectedEndOfExecution())