diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 30276cb1..daac9822 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -17,7 +17,7 @@ def template(): @pytest.fixture() def sandbox(template, debug): - sandbox = Sandbox(template, timeout=timeout) + sandbox = Sandbox(template, timeout=timeout, debug=debug) try: yield sandbox @@ -33,7 +33,7 @@ def sandbox(template, debug): @pytest_asyncio.fixture async def async_sandbox(template, debug): - async_sandbox = await AsyncSandbox.create(template, timeout=timeout) + async_sandbox = await AsyncSandbox.create(template, timeout=timeout, debug=debug) try: yield async_sandbox diff --git a/template/server/main.py b/template/server/main.py index 59b188a2..86e57e17 100644 --- a/template/server/main.py +++ b/template/server/main.py @@ -108,11 +108,6 @@ async def post_execute(request: ExecutionRequest): status_code=404, ) - # set global env vars if not set on first execution - if not ws.global_env_vars: - ws.global_env_vars = await get_envs() - await ws.set_env_vars(ws.global_env_vars) - return StreamingListJsonResponse( ws.execute( request.code, diff --git a/template/server/messaging.py b/template/server/messaging.py index 7bc2902f..7c295b2d 100644 --- a/template/server/messaging.py +++ b/template/server/messaging.py @@ -23,6 +23,7 @@ UnexpectedEndOfExecution, ) from errors import ExecutionError +from envs import get_envs logger = logging.getLogger(__name__) @@ -47,7 +48,8 @@ def __init__(self, in_background: bool = False): class ContextWebSocket: _ws: Optional[WebSocketClientProtocol] = None _receive_task: Optional[asyncio.Task] = None - global_env_vars: Optional[Dict[StrictStr, str]] = None + _global_env_vars: Optional[Dict[StrictStr, str]] = None + _cleanup_task: Optional[asyncio.Task] = None def __init__( self, @@ -114,6 +116,113 @@ def _get_execute_request( } ) + def _set_env_var_snippet(self, key: str, value: str) -> str: + """Get environment variable set command for the current language.""" + if self.language == "python": + return f"import os; os.environ['{key}'] = '{value}'" + elif self.language in ["javascript", "typescript"]: + return f"process.env['{key}'] = '{value}'" + elif self.language == "deno": + return f"Deno.env.set('{key}', '{value}')" + elif self.language == "r": + return f'Sys.setenv({key} = "{value}")' + elif self.language == "java": + return f'System.setProperty("{key}", "{value}");' + elif self.language == "bash": + return f"export {key}='{value}'" + return "" + + def _delete_env_var_snippet(self, key: str) -> str: + """Get environment variable delete command for the current language.""" + if self.language == "python": + return f"import os; del os.environ['{key}']" + elif self.language in ["javascript", "typescript"]: + return f"delete process.env['{key}']" + elif self.language == "deno": + return f"Deno.env.delete('{key}')" + elif self.language == "r": + return f"Sys.unsetenv('{key}')" + elif self.language == "java": + return f'System.clearProperty("{key}");' + elif self.language == "bash": + return f"unset {key}" + return "" + + def _set_env_vars_code(self, env_vars: Dict[StrictStr, str]) -> str: + """Build environment variable code for the current language.""" + env_commands = [] + for k, v in env_vars.items(): + command = self._set_env_var_snippet(k, v) + if command: + env_commands.append(command) + + return "\n".join(env_commands) + + def _reset_env_vars_code(self, env_vars: Dict[StrictStr, str]) -> str: + """Build environment variable cleanup code for the current language.""" + cleanup_commands = [] + + for key in env_vars: + # Check if this var exists in global env vars + if self._global_env_vars and key in self._global_env_vars: + # Reset to global value + value = self._global_env_vars[key] + command = self._set_env_var_snippet(key, value) + else: + # Remove the variable + command = self._delete_env_var_snippet(key) + + if command: + cleanup_commands.append(command) + + return "\n".join(cleanup_commands) + + def _get_code_indentation(self, code: str) -> str: + """Get the indentation from the first non-empty line of code.""" + if not code or not code.strip(): + return "" + + lines = code.split('\n') + for line in lines: + if line.strip(): # First non-empty line + return line[:len(line) - len(line.lstrip())] + + return "" + + def _indent_code_with_level(self, code: str, indent_level: str) -> str: + """Apply the given indentation level to each line of code.""" + if not code or not indent_level: + return code + + lines = code.split('\n') + indented_lines = [] + + for line in lines: + if line.strip(): # Non-empty lines + indented_lines.append(indent_level + line) + else: + indented_lines.append(line) + + return '\n'.join(indented_lines) + + async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]): + """Clean up environment variables in a separate execution request.""" + message_id = str(uuid.uuid4()) + self._executions[message_id] = Execution(in_background=True) + + try: + cleanup_code = self._reset_env_vars_code(env_vars) + if cleanup_code: + logger.info(f"Cleaning up env vars: {cleanup_code}") + request = self._get_execute_request(message_id, cleanup_code, True) + await self._ws.send(request) + + async for item in self._wait_for_result(message_id): + if item["type"] == "error": + logger.error(f"Error during env var cleanup: {item}") + finally: + del self._executions[message_id] + async def _wait_for_result(self, message_id: str): queue = self._executions[message_id].queue @@ -133,84 +242,6 @@ async def _wait_for_result(self, message_id: str): yield output.model_dump(exclude_none=True) - async def set_env_vars(self, env_vars: Dict[StrictStr, str]): - message_id = str(uuid.uuid4()) - self._executions[message_id] = Execution(in_background=True) - - env_commands = [] - for k, v in env_vars.items(): - if self.language == "python": - env_commands.append(f"import os; os.environ['{k}'] = '{v}'") - elif self.language in ["javascript", "typescript"]: - env_commands.append(f"process.env['{k}'] = '{v}'") - elif self.language == "deno": - env_commands.append(f"Deno.env.set('{k}', '{v}')") - elif self.language == "r": - env_commands.append(f'Sys.setenv({k} = "{v}")') - elif self.language == "java": - env_commands.append(f'System.setProperty("{k}", "{v}");') - elif self.language == "bash": - env_commands.append(f"export {k}='{v}'") - else: - return - - if env_commands: - env_vars_snippet = "\n".join(env_commands) - logger.info(f"Setting env vars: {env_vars_snippet} for {self.language}") - request = self._get_execute_request(message_id, env_vars_snippet, True) - await self._ws.send(request) - - async for item in self._wait_for_result(message_id): - if item["type"] == "error": - raise ExecutionError(f"Error during execution: {item}") - - async def reset_env_vars(self, env_vars: Dict[StrictStr, str]): - # Create a dict of vars to reset and a list of vars to remove - vars_to_reset = {} - vars_to_remove = [] - - for key in env_vars: - if self.global_env_vars and key in self.global_env_vars: - vars_to_reset[key] = self.global_env_vars[key] - else: - vars_to_remove.append(key) - - # Reset vars that exist in global env vars - if vars_to_reset: - await self.set_env_vars(vars_to_reset) - - # Remove vars that don't exist in global env vars - if vars_to_remove: - message_id = str(uuid.uuid4()) - self._executions[message_id] = Execution(in_background=True) - - remove_commands = [] - for key in vars_to_remove: - if self.language == "python": - remove_commands.append(f"import os; del os.environ['{key}']") - elif self.language in ["javascript", "typescript"]: - remove_commands.append(f"delete process.env['{key}']") - elif self.language == "deno": - remove_commands.append(f"Deno.env.delete('{key}')") - elif self.language == "r": - remove_commands.append(f"Sys.unsetenv('{key}')") - elif self.language == "java": - remove_commands.append(f'System.clearProperty("{key}");') - elif self.language == "bash": - remove_commands.append(f"unset {key}") - else: - return - - if remove_commands: - remove_snippet = "\n".join(remove_commands) - logger.info(f"Removing env vars: {remove_snippet} for {self.language}") - request = self._get_execute_request(message_id, remove_snippet, True) - await self._ws.send(request) - - async for item in self._wait_for_result(message_id): - if item["type"] == "error": - raise ExecutionError(f"Error during execution: {item}") - async def change_current_directory( self, path: Union[str, StrictStr], language: str ): @@ -248,20 +279,44 @@ async def execute( env_vars: Dict[StrictStr, str] = None, ): message_id = str(uuid.uuid4()) - logger.debug(f"Sending code for the execution ({message_id}): {code}") - self._executions[message_id] = Execution() if self._ws is None: raise Exception("WebSocket not connected") async with self._lock: - # set env vars (will override global env vars) + # Wait for any pending cleanup task to complete + if self._cleanup_task and not self._cleanup_task.done(): + logger.debug("Waiting for pending cleanup task to complete") + try: + await self._cleanup_task + except Exception as e: + logger.warning(f"Cleanup task failed: {e}") + finally: + self._cleanup_task = None + + # Get the indentation level from the code + code_indent = self._get_code_indentation(code) + + # Build the complete code snippet with env vars + complete_code = code + + global_env_vars_snippet = "" + env_vars_snippet = "" + + if self._global_env_vars is None: + self._global_env_vars = await get_envs() + global_env_vars_snippet = self._set_env_vars_code(self._global_env_vars) + if env_vars: - await self.set_env_vars(env_vars) + env_vars_snippet = self._set_env_vars_code(env_vars) - logger.info(code) - request = self._get_execute_request(message_id, code, False) + if global_env_vars_snippet or env_vars_snippet: + indented_env_code = self._indent_code_with_level(f"{global_env_vars_snippet}\n{env_vars_snippet}", code_indent) + complete_code = f"{indented_env_code}\n{complete_code}" + + logger.info(f"Sending code for the execution ({message_id}): {complete_code}") + request = self._get_execute_request(message_id, complete_code, False) # Send the code for execution await self._ws.send(request) @@ -272,9 +327,9 @@ async def execute( del self._executions[message_id] - # reset env vars to their previous values, if they were set globally or remove them + # Clean up env vars in a separate request after the main code has run if env_vars: - await self.reset_env_vars(env_vars) + self._cleanup_task = asyncio.create_task(self._cleanup_env_vars(env_vars)) async def _receive_message(self): if not self._ws: @@ -434,7 +489,16 @@ async def close(self): if self._ws is not None: await self._ws.close() - self._receive_task.cancel() + if self._receive_task is not None: + self._receive_task.cancel() + + # Cancel any pending cleanup task + if self._cleanup_task and not self._cleanup_task.done(): + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass for execution in self._executions.values(): execution.queue.put_nowait(UnexpectedEndOfExecution())