"""Requirements for Python code generation validation."""

import ast
import re
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import Context
from mellea.stdlib.requirement import Requirement, ValidationResult

logger = FancyLogger.get_logger()

# Maximum number of stderr characters echoed back in an error message.
_MAX_ERROR_CHARS = 200

# Error text shared by every backend when the import allow-list is violated.
_UNAUTHORIZED_IMPORTS_ERROR = "Unauthorized imports detected"

# region execution backends


@dataclass
class ExecutionResult:
    """Result of code execution."""

    # True when the code passed validation / exited with status 0.
    success: bool
    # Human-readable note describing a successful outcome.
    message: str | None = None
    # Human-readable description of the failure, if any.
    error: str | None = None
    # True when the code was only validated and never actually run.
    skipped: bool = False


def _stderr_excerpt(stderr: str | None) -> str:
    """Return at most ``_MAX_ERROR_CHARS`` characters from the end of *stderr*.

    The tail is kept rather than the head because a Python traceback ends with
    the exception type and message, which is the part worth reporting.
    """
    if not stderr:
        return ""
    return stderr.strip()[-_MAX_ERROR_CHARS:]


def _tree_uses_only_allowed_imports(tree: ast.AST, allowed_imports: list[str]) -> bool:
    """Return True if every import statement in *tree* targets an allowed module.

    Only the top-level package name is compared (``os.path`` is checked as
    ``os``). Relative imports are always rejected because they cannot be
    matched against an allow-list of absolute module names.
    """
    allowed = frozenset(allowed_imports)
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            if any(alias.name.split(".")[0] not in allowed for alias in node.names):
                return False
        elif isinstance(node, ast.ImportFrom):
            # ``level > 0`` means ``from . import x`` / ``from .pkg import x``.
            if node.level or not node.module:
                return False
            if node.module.split(".")[0] not in allowed:
                return False
    return True


def _precheck_imports(
    code: str, allowed_imports: list[str] | None
) -> ExecutionResult | None:
    """Apply the import allow-list before handing code to an executing backend.

    Returns:
        ``None`` when execution may proceed, otherwise a failed
        ``ExecutionResult`` describing why the code was rejected.
    """
    # ``None`` means "no restriction"; an empty list means "no imports at all".
    if allowed_imports is None:
        return None

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError) as e:
        # Report the real problem instead of a misleading import error.
        return ExecutionResult(success=False, error=str(e))

    if not _tree_uses_only_allowed_imports(tree, allowed_imports):
        return ExecutionResult(success=False, error=_UNAUTHORIZED_IMPORTS_ERROR)
    return None


class ExecutionBackend(ABC):
    """Abstract backend for executing Python code."""

    @abstractmethod
    def execute(self, code: str, timeout: int) -> ExecutionResult:
        """Execute code and return result."""


class SafeBackend(ExecutionBackend):
    """Safe backend that validates but does not execute code."""

    def __init__(self, allowed_imports: list[str] | None = None):
        """Initialize with optional import restrictions.

        Args:
            allowed_imports: Top-level modules the code may import. ``None``
                disables the check; an empty list forbids every import.
        """
        self.allowed_imports = allowed_imports

    def execute(self, code: str, timeout: int) -> ExecutionResult:
        """Validate code syntax and imports without executing.

        Args:
            code: Python source to validate.
            timeout: Unused; accepted to satisfy the backend interface.

        Returns:
            A successful result flagged ``skipped=True`` when the code parses
            and respects the allow-list, otherwise a failed result.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError) as e:
            return ExecutionResult(success=False, error=str(e))

        if self.allowed_imports is not None and not _tree_uses_only_allowed_imports(
            tree, self.allowed_imports
        ):
            return ExecutionResult(success=False, error=_UNAUTHORIZED_IMPORTS_ERROR)

        return ExecutionResult(
            success=True,
            skipped=True,
            message="Code validated but not executed (safe mode)",
        )


class UnsafeBackend(ExecutionBackend):
    """Unsafe backend that executes code directly with subprocess."""

    def __init__(self, allowed_imports: list[str] | None = None):
        """Initialize with optional import restrictions.

        Args:
            allowed_imports: Top-level modules the code may import. ``None``
                disables the check; an empty list forbids every import.
        """
        self.allowed_imports = allowed_imports

    def execute(self, code: str, timeout: int) -> ExecutionResult:
        """Execute code with subprocess after checking imports.

        Args:
            code: Python source to run.
            timeout: Maximum number of seconds the subprocess may run.
        """
        rejection = _precheck_imports(code, self.allowed_imports)
        if rejection is not None:
            return rejection

        return self._execute_subprocess(code, timeout)

    def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult:
        """Write *code* to a temporary script and run it with this interpreter."""
        temp_path: Path | None = None
        try:
            # Python reads source files as UTF-8, so write the script as UTF-8
            # regardless of the platform's locale encoding.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".py", delete=False, encoding="utf-8"
            ) as f:
                temp_path = Path(f.name)
                f.write(code)

            result = subprocess.run(
                [sys.executable, str(temp_path)],
                capture_output=True,
                text=True,
                errors="replace",
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return ExecutionResult(
                success=False, error=f"Execution timed out after {timeout} seconds"
            )
        except Exception as e:
            # Boundary handler: a backend reports failures, it does not raise.
            return ExecutionResult(success=False, error=f"Execution error: {e!s}")
        finally:
            if temp_path is not None:
                try:
                    temp_path.unlink(missing_ok=True)
                except OSError as e:
                    # Cleanup is best-effort, but leave a trace if it fails.
                    logger.warning(
                        f"Could not remove temporary file {temp_path}: {e!s}"
                    )

        if result.returncode == 0:
            return ExecutionResult(success=True, message="Code executed successfully")

        return ExecutionResult(
            success=False,
            error=f"Execution failed with error: {_stderr_excerpt(result.stderr)}",
        )


class LLMSandboxBackend(ExecutionBackend):
    """Backend using llm-sandbox for secure Docker-based execution."""

    def __init__(self, allowed_imports: list[str] | None = None):
        """Initialize with optional import restrictions.

        Args:
            allowed_imports: Top-level modules the code may import. ``None``
                disables the check; an empty list forbids every import.
        """
        self.allowed_imports = allowed_imports

    def execute(self, code: str, timeout: int) -> ExecutionResult:
        """Execute code using llm-sandbox.

        Args:
            code: Python source to run.
            timeout: Timeout in seconds forwarded to the sandbox session.
        """
        rejection = _precheck_imports(code, self.allowed_imports)
        if rejection is not None:
            return rejection

        # Imported lazily so that a missing llm-sandbox install is reported as
        # a failed result rather than an exception.
        try:
            from llm_sandbox import SandboxSession
        except ImportError:
            return ExecutionResult(
                success=False,
                error="llm-sandbox not installed. Install with: uv add 'llm-sandbox[docker]'",
            )

        try:
            with SandboxSession(
                lang="python", verbose=False, keep_template=False
            ) as session:
                result = session.run(code, timeout=timeout)
        except Exception as e:
            # Boundary handler: a backend reports failures, it does not raise.
            return ExecutionResult(
                success=False, error=f"Sandbox execution error: {e!s}"
            )

        if result.exit_code == 0:
            return ExecutionResult(
                success=True, message="Code executed successfully in sandbox"
            )

        detail = _stderr_excerpt(result.stderr) or "Unknown error"
        return ExecutionResult(
            success=False, error=f"Sandbox execution failed: {detail}"
        )


def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool:
    """Check if code only uses allowed imports.

    Args:
        code: Python source to inspect.
        allowed_imports: Top-level module names that may be imported.

    Returns:
        False when the code cannot be parsed, uses a relative import, or
        imports a module whose top-level package is not listed; True otherwise.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        return False
    return _tree_uses_only_allowed_imports(tree, allowed_imports)


# endregion

# region code extraction

# A fenced block: opening fence with an optional info string, the body, and a
# closing fence at the start of a line. Matching whole blocks in a single pass
# keeps fences correctly paired, so text *between* two blocks is never read
# as a block of its own.
_FENCED_BLOCK_RE = re.compile(
    r"```([^\n`]*)\r?\n(.*?)^[ \t]*```", re.DOTALL | re.MULTILINE
)

# Fence language tags treated as Python.
_PYTHON_FENCE_TAGS = frozenset({"python", "py", "python3"})

# Score bonus for blocks explicitly tagged as Python.
_PYTHON_TAG_BONUS = 10

# Substrings suggesting that an untagged block contains Python.
_PYTHON_HINTS = ("def ", "class ", "import ", "print(", "if __name__")

# Substrings suggesting that a block contains real control flow.
_LOGIC_KEYWORDS = ("if ", "for ", "while ", "try:", "with ")


def _score_code_block(code: str, context_text: str = "") -> int:
    """Score a code block to determine if it's likely the main answer.

    Args:
        code: The candidate code block.
        context_text: Currently unused; kept for interface compatibility.

    Returns:
        A heuristic score; higher means more likely to be the main answer.
    """
    lines = code.split("\n")

    # Longer blocks are generally better, capped so size alone cannot dominate.
    score = min(len(lines), 10)

    # Prefer complete functions/classes.
    if "def " in code or "class " in code:
        score += 5

    # Prefer blocks with actual logic.
    if any(keyword in code for keyword in _LOGIC_KEYWORDS):
        score += 3

    # Penalize blocks that are almost entirely imports or comments.
    non_trivial_lines = [
        stripped
        for stripped in (line.strip() for line in lines)
        if stripped and not stripped.startswith(("#", "import ", "from "))
    ]
    if len(non_trivial_lines) < 2:
        score -= 5

    return score


def _has_python_code_listing(ctx: Context) -> ValidationResult:
    """Extract Python code from context.

    Scans the last output in *ctx* for fenced code blocks. Blocks tagged as
    Python are preferred; untagged blocks are considered only when they look
    like Python; blocks tagged with any other language are ignored.

    Returns:
        On success, a passing ``ValidationResult`` whose ``reason`` holds the
        highest-scoring code block. Otherwise a failing result whose
        ``reason`` explains what was missing.
    """
    last_output = ctx.last_output()
    if last_output is None or last_output.value is None:
        return ValidationResult(result=False, reason="No output found in context")

    content = last_output.value

    tagged: list[tuple[str, int]] = []
    untagged: list[tuple[str, int]] = []
    for match in _FENCED_BLOCK_RE.finditer(content):
        info = match.group(1).strip()
        # The language is the first word of the fence's info string, if any.
        language = info.split()[0].lower() if info else ""
        block = match.group(2).strip()
        if not block:
            continue

        if language in _PYTHON_FENCE_TAGS:
            tagged.append(
                (block, _score_code_block(block, content) + _PYTHON_TAG_BONUS)
            )
        elif not language and any(hint in block for hint in _PYTHON_HINTS):
            untagged.append((block, _score_code_block(block, content)))

    # Tagged blocks come first so that they win score ties.
    candidates = tagged + untagged
    if not candidates:
        return ValidationResult(result=False, reason="No Python code blocks found")

    best_code, _ = max(candidates, key=lambda candidate: candidate[1])
    return ValidationResult(result=True, reason=best_code)


# endregion

# region execution validation


def _python_executes_without_error(
    ctx: Context,
    timeout: int = 5,
    allow_unsafe: bool = False,
    allowed_imports: list[str] | None = None,
    use_sandbox: bool = False,
) -> ValidationResult:
    """Validate that Python code executes without raising exceptions.

    Args:
        ctx: Context whose last output is searched for a Python code block.
        timeout: Maximum seconds the code may run in an executing backend.
        allow_unsafe: If True (and ``use_sandbox`` is False), run the code
            directly in a subprocess.
        allowed_imports: Top-level modules the code may import. ``None``
            disables the check; an empty list forbids every import.
        use_sandbox: If True, run the code through llm-sandbox. Takes
            precedence over ``allow_unsafe``.

    Returns:
        A ``ValidationResult`` whose ``reason`` is the backend's message on
        success or its error on failure.
    """
    extraction_result = _has_python_code_listing(ctx)
    code = extraction_result.reason
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not extraction_result.as_bool() or code is None:
        return ValidationResult(
            result=False, reason="Could not extract Python code for execution"
        )

    backend: ExecutionBackend
    if use_sandbox:
        backend = LLMSandboxBackend(allowed_imports=allowed_imports)
    elif allow_unsafe:
        backend = UnsafeBackend(allowed_imports=allowed_imports)
    else:
        backend = SafeBackend(allowed_imports=allowed_imports)

    result = backend.execute(code, timeout)
    return ValidationResult(
        result=result.success, reason=result.message or result.error
    )


class PythonExecutesWithoutError(Requirement):
    """Verifies that Python code runs without raising exceptions."""

    def __init__(
        self,
        timeout: int = 5,
        allow_unsafe_execution: bool = False,
        allowed_imports: list[str] | None = None,
        use_sandbox: bool = False,
    ):
        """Initialize execution validator.

        Args:
            timeout: Maximum seconds to allow code to run before timing out.
            allow_unsafe_execution: If True, execute code directly with subprocess (unsafe).
            allowed_imports: List of allowed import modules when using execution.
                ``None`` disables the check; an empty list forbids every import.
            use_sandbox: If True, use llm-sandbox for secure Docker-based execution.
                Takes precedence over ``allow_unsafe_execution``.
        """
        self._timeout = timeout
        self._allow_unsafe = allow_unsafe_execution
        self._allowed_imports = allowed_imports
        self._use_sandbox = use_sandbox

        if allow_unsafe_execution and not use_sandbox:
            logger.warning(
                "⚠️ UNSAFE: Executing untrusted code directly. Only use with trusted sources!"
            )

        # Same precedence as backend selection: sandbox, then unsafe, then safe.
        if use_sandbox:
            execution_mode = f"sandbox execution (timeout: {timeout}s)"
        elif allow_unsafe_execution:
            execution_mode = f"unsafe execution (timeout: {timeout}s)"
        else:
            execution_mode = "validation only"

        super().__init__(
            description=f"The Python code should execute without errors ({execution_mode}).",
            validation_fn=lambda ctx: _python_executes_without_error(
                ctx,
                self._timeout,
                self._allow_unsafe,
                self._allowed_imports,
                self._use_sandbox,
            ),
            check_only=True,
        )


# endregion
"""Tests for Python code execution requirements."""

import sys

import pytest

from mellea.stdlib.base import ChatContext, Context, ModelOutputThunk
from mellea.stdlib.reqlib.python import (
    PythonExecutesWithoutError,
    _has_python_code_listing,
    _python_executes_without_error,
)


def from_model(content: str) -> Context:
    """Helper to create context from model output."""
    ctx = ChatContext()
    ctx = ctx.add(ModelOutputThunk(value=content))
    return ctx


# Sandbox tests are skipped unconditionally: they need a running Docker daemon
# and a working llm-sandbox setup, which CI may not provide.
requires_sandbox = pytest.mark.skip(
    reason="Sandbox tests require Docker and llm-sandbox setup"
)


# Test contexts
VALID_PYTHON_CODE = """```python
def hello_world():
    return "Hello, World!"

print(hello_world())
```"""

PYTHON_WITH_SYNTAX_ERROR = """```python
def hello_world(
    return "Hello, World!"
```"""

PYTHON_WITH_RUNTIME_ERROR = """```python
def divide_by_zero():
    return 1 / 0

divide_by_zero()
```"""

PYTHON_WITH_IMPORTS = """```python
import os
import sys
from pathlib import Path

print("Hello from imports!")
```"""

PYTHON_WITH_FORBIDDEN_IMPORTS = """```python
import subprocess
import socket
import urllib

print("Dangerous imports!")
```"""

PYTHON_SIMPLE_PRINT = """```python
print("Hello, World!")
```"""

PYTHON_INFINITE_LOOP = """```python
while True:
    pass
```"""

NO_PYTHON_CODE = """
This is just text without any Python code blocks.
It contains no executable content.
"""

# Create contexts
VALID_PYTHON_CTX = from_model(VALID_PYTHON_CODE)
SYNTAX_ERROR_CTX = from_model(PYTHON_WITH_SYNTAX_ERROR)
RUNTIME_ERROR_CTX = from_model(PYTHON_WITH_RUNTIME_ERROR)
PYTHON_WITH_IMPORTS_CTX = from_model(PYTHON_WITH_IMPORTS)
PYTHON_WITH_FORBIDDEN_IMPORTS_CTX = from_model(PYTHON_WITH_FORBIDDEN_IMPORTS)
PYTHON_SIMPLE_CTX = from_model(PYTHON_SIMPLE_PRINT)
PYTHON_INFINITE_LOOP_CTX = from_model(PYTHON_INFINITE_LOOP)
NO_PYTHON_CTX = from_model(NO_PYTHON_CODE)


# region: Code extraction tests


def test_has_python_code_listing_valid():
    """Test extraction of valid Python code."""
    result = _has_python_code_listing(VALID_PYTHON_CTX)
    assert result.as_bool() is True
    assert "def hello_world" in result.reason


def test_has_python_code_listing_no_code():
    """Test handling when no Python code is present."""
    result = _has_python_code_listing(NO_PYTHON_CTX)
    assert result.as_bool() is False
    assert "No Python code blocks found" in result.reason


def test_has_python_code_listing_simple():
    """Test extraction of simple Python code."""
    result = _has_python_code_listing(PYTHON_SIMPLE_CTX)
    assert result.as_bool() is True
    assert "print" in result.reason


# endregion

# region: Safe mode tests (default behavior)


def test_safe_mode_default():
    """Test that safe mode is default and validates without executing."""
    req = PythonExecutesWithoutError()
    result = req.validation_fn(VALID_PYTHON_CTX)
    assert result.as_bool() is True
    assert "safe mode" in result.reason


def test_safe_mode_syntax_error():
    """Test that safe mode catches syntax errors."""
    req = PythonExecutesWithoutError()
    result = req.validation_fn(SYNTAX_ERROR_CTX)
    assert result.as_bool() is False


def test_safe_mode_no_execution():
    """Test that safe mode doesn't execute code (even infinite loops)."""
    req = PythonExecutesWithoutError(timeout=1)
    result = req.validation_fn(PYTHON_INFINITE_LOOP_CTX)
    assert result.as_bool() is True  # Passes because the loop is never run
    assert "safe mode" in result.reason


# endregion

# region: Unsafe execution tests


def test_unsafe_execution_valid():
    """Test unsafe execution with valid code."""
    req = PythonExecutesWithoutError(allow_unsafe_execution=True, timeout=5)
    result = req.validation_fn(VALID_PYTHON_CTX)
    assert result.as_bool() is True


def test_unsafe_execution_runtime_error():
    """Test unsafe execution with runtime error."""
    req = PythonExecutesWithoutError(allow_unsafe_execution=True, timeout=5)
    result = req.validation_fn(RUNTIME_ERROR_CTX)
    assert result.as_bool() is False
    assert "error" in result.reason.lower()


def test_unsafe_execution_timeout():
    """Test unsafe execution with timeout."""
    req = PythonExecutesWithoutError(allow_unsafe_execution=True, timeout=1)
    result = req.validation_fn(PYTHON_INFINITE_LOOP_CTX)
    assert result.as_bool() is False
    assert "timed out" in result.reason.lower()


def test_unsafe_execution_syntax_error():
    """Test unsafe execution with syntax error."""
    req = PythonExecutesWithoutError(allow_unsafe_execution=True)
    result = req.validation_fn(SYNTAX_ERROR_CTX)
    assert result.as_bool() is False


# endregion

# region: Import restriction tests


def test_import_restrictions_block_forbidden():
    """Test that import restrictions block forbidden imports."""
    req = PythonExecutesWithoutError(
        allow_unsafe_execution=True, allowed_imports=["os", "sys"]
    )
    result = req.validation_fn(PYTHON_WITH_FORBIDDEN_IMPORTS_CTX)
    assert result.as_bool() is False
    assert "Unauthorized imports" in result.reason


def test_import_restrictions_allow_permitted():
    """Test that import restrictions allow permitted imports."""
    req = PythonExecutesWithoutError(
        allow_unsafe_execution=True, allowed_imports=["os", "sys", "pathlib"]
    )
    result = req.validation_fn(PYTHON_WITH_IMPORTS_CTX)
    assert result.as_bool() is True


def test_import_restrictions_with_safe_mode():
    """Test that import restrictions work with safe mode."""
    req = PythonExecutesWithoutError(allowed_imports=["os", "sys"])
    result = req.validation_fn(PYTHON_WITH_FORBIDDEN_IMPORTS_CTX)
    assert result.as_bool() is False
    assert "Unauthorized imports" in result.reason


# endregion

# region: Sandbox execution tests


@requires_sandbox
def test_sandbox_execution_valid():
    """Test sandbox execution with valid code."""
    req = PythonExecutesWithoutError(use_sandbox=True, timeout=10)
    result = req.validation_fn(VALID_PYTHON_CTX)
    assert result.as_bool() is True
    assert "sandbox" in result.reason.lower()


@requires_sandbox
def test_sandbox_execution_with_imports():
    """Test sandbox execution with allowed imports."""
    req = PythonExecutesWithoutError(
        use_sandbox=True, allowed_imports=["os", "sys", "pathlib"], timeout=10
    )
    result = req.validation_fn(PYTHON_WITH_IMPORTS_CTX)
    assert result.as_bool() is True


@requires_sandbox
def test_sandbox_execution_timeout():
    """Test sandbox execution timeout."""
    req = PythonExecutesWithoutError(use_sandbox=True, timeout=2)
    result = req.validation_fn(PYTHON_INFINITE_LOOP_CTX)
    assert result.as_bool() is False


def test_sandbox_without_llm_sandbox_installed(monkeypatch):
    """Test graceful handling when llm-sandbox is not available."""
    # A ``None`` entry in ``sys.modules`` makes ``import llm_sandbox`` raise
    # ImportError, simulating a missing install even when the package exists.
    monkeypatch.setitem(sys.modules, "llm_sandbox", None)

    req = PythonExecutesWithoutError(use_sandbox=True)
    result = req.validation_fn(VALID_PYTHON_CTX)
    assert result.as_bool() is False
    assert "llm-sandbox not installed" in result.reason


# endregion

# region: Configuration tests


def test_description_updates_based_on_mode():
    """Test that requirement description reflects execution mode."""
    safe_req = PythonExecutesWithoutError()
    assert "validation only" in safe_req.description

    unsafe_req = PythonExecutesWithoutError(allow_unsafe_execution=True, timeout=5)
    assert "unsafe execution" in unsafe_req.description
    assert "timeout: 5s" in unsafe_req.description

    sandbox_req = PythonExecutesWithoutError(use_sandbox=True, timeout=10)
    assert "sandbox execution" in sandbox_req.description
    assert "timeout: 10s" in sandbox_req.description


def test_parameter_combinations():
    """Test various parameter combinations."""
    # Safe mode (default)
    req1 = PythonExecutesWithoutError()
    assert req1._allow_unsafe is False
    assert req1._use_sandbox is False

    # Unsafe mode
    req2 = PythonExecutesWithoutError(allow_unsafe_execution=True)
    assert req2._allow_unsafe is True
    assert req2._use_sandbox is False

    # Sandbox mode
    req3 = PythonExecutesWithoutError(use_sandbox=True)
    assert req3._allow_unsafe is False
    assert req3._use_sandbox is True

    # Sandbox + unsafe: both flags are stored and the description reports
    # sandbox execution.
    req4 = PythonExecutesWithoutError(allow_unsafe_execution=True, use_sandbox=True)
    assert req4._allow_unsafe is True
    assert req4._use_sandbox is True
    assert "sandbox execution" in req4.description


# endregion

# region: Integration tests


def test_direct_validation_function():
    """Test calling validation function directly."""
    result = _python_executes_without_error(
        VALID_PYTHON_CTX, timeout=5, allow_unsafe=False, use_sandbox=False
    )
    assert result.as_bool() is True
    assert "safe mode" in result.reason

    result = _python_executes_without_error(
        SYNTAX_ERROR_CTX, timeout=5, allow_unsafe=False, use_sandbox=False
    )
    assert result.as_bool() is False


def test_no_code_extraction():
    """Test behavior when no code can be extracted."""
    req = PythonExecutesWithoutError()
    result = req.validation_fn(NO_PYTHON_CTX)
    assert result.as_bool() is False
    assert "Could not extract Python code" in result.reason


# endregion
"2024-05-23T11:13:57.216Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, +] + [[package]] name = "docling" version = "2.51.0" @@ -2154,6 +2168,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/d9/5f8ed27241b487f51f04573b8ba06d4460ebed9f792ff5cc148649fbf862/litellm-1.76.3-py3-none-any.whl", hash = "sha256:d62e3ff2a80ec5e551c6d7a0fe199ffe718ecb6cbaa43fc9250dd8d7c0944352", size = 9000797, upload-time = "2025-09-07T01:59:16.261Z" }, ] +[[package]] +name = "llm-sandbox" +version = "0.3.23" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/68/395f8a0aa5cc0cab73b8cb08c3a4531d834068e9c385f4f53fac41414ace/llm_sandbox-0.3.23.tar.gz", hash = "sha256:0062ac66b73c777a24958fbd03a3d4fc321d147cdb39e73214881012d8bd6b5f", size = 495511, upload-time = "2025-10-21T04:37:55.684Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6b/fc/b8dc6099cf411c16f3f19ad0c6e6618e2b2b374640bc34f568dd4cd4fe39/llm_sandbox-0.3.23-py3-none-any.whl", hash = "sha256:08a9b61075d36bfa3258c779012bc852d9a27b40d5cab604acd7d57ca0e94cba", size = 69825, upload-time = "2025-10-21T04:37:54.4Z" }, +] + +[package.optional-dependencies] +docker = [ + { name = "docker" }, +] + [[package]] name = "lomond" version = "0.3.3" @@ -2380,6 +2411,7 @@ dependencies = [ { name = "huggingface-hub" }, { name = "jinja2" }, { name = "json5" }, + { name = "llm-sandbox", extra = ["docker"] }, { name = "math-verify" }, { name = "mistletoe" }, { name = "ollama" }, @@ -2465,6 +2497,7 @@ requires-dist = [ { name = "jinja2" }, { name = "json5" }, { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.76" }, + { name = "llm-sandbox", 
extras = ["docker"], specifier = ">=0.3.23" }, { name = "math-verify" }, { name = "mellea", extras = ["watsonx", "docling", "hf", "litellm"], marker = "extra == 'all'" }, { name = "mistletoe", specifier = ">=1.4.0" },