class SafeBackend(ExecutionBackend):
    """Backend that only checks syntax; the submitted code is never run."""

    def execute(self, code: str, timeout: int) -> ExecutionResult:
        """Parse ``code`` and report whether it is syntactically valid.

        Args:
            code: Python source to check.
            timeout: Accepted for interface compatibility; not used here.

        Returns:
            A successful result flagged ``skipped=True`` when the code parses,
            otherwise a failed result whose ``error`` is the SyntaxError text.
        """
        try:
            ast.parse(code)
        except SyntaxError as syntax_error:
            return ExecutionResult(success=False, error=str(syntax_error))

        return ExecutionResult(
            success=True,
            skipped=True,
            message="Code validated but not executed (safe mode)",
        )
self.allowed_imports): + return ExecutionResult(success=False, error="Unauthorized imports detected") + + return self._execute_subprocess(code, timeout) + + def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write(code) + temp_file = f.name + + try: + result = subprocess.run( + [sys.executable, temp_file], capture_output=True, text=True, timeout=timeout + ) + + if result.returncode == 0: + return ExecutionResult(success=True, message="Code executed successfully") + else: + return ExecutionResult( + success=False, + error=f"Execution failed with error: {result.stderr[:200]}", + ) + except subprocess.TimeoutExpired: + return ExecutionResult( + success=False, error=f"Execution timed out after {timeout} seconds" + ) + except Exception as e: + return ExecutionResult(success=False, error=f"Execution error: {e!s}") + finally: + try: + Path(temp_file).unlink() + except Exception: + pass + + +def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + base_module = alias.name.split(".")[0] + if base_module not in allowed_imports: + return False + elif isinstance(node, ast.ImportFrom): + if node.module: + base_module = node.module.split(".")[0] + if base_module not in allowed_imports: + return False + return True + + +# endregion + +# region code extraction + + +def _score_code_block(code: str, context_text: str = "") -> int: + """Score a code block to determine if it's likely the mainI like i answer. 
+ + Returns higher score for blocks that: + - Are longer (more substantial) + - Contain function/class definitions (not just tests) + - Don't look like test code + - Appear after positive language cues + """ + score = 0 + + # Length bonus (longer is usually the main implementation) + score += len(code) // 50 # 1 point per ~50 chars + + # Has function/class definitions (not just calls) + if "def " in code: + score += 10 + if "class " in code: + score += 10 + + # Penalty for test code indicators + test_indicators = ["unittest", "pytest", "assert ", "test_", "TestCase", "def test"] + if any(indicator in code for indicator in test_indicators): + score -= 20 + + # Bonus if preceded by positive language + if context_text: + positive_phrases = ["here's the", "correct", "solution", "final", "here is the"] + # Check text before this code block + if any(phrase in context_text.lower() for phrase in positive_phrases): + score += 5 + + # Penalty for being preceded by negative language + if context_text: + negative_phrases = ["wrong", "bad", "don't", "avoid", "incorrect", "won't work"] + if any(phrase in context_text.lower() for phrase in negative_phrases): + score -= 15 + + return score + + +def extract_python_code(text: str) -> str | None: + """Extract Python code from markdown code blocks or plain text. + + Uses intelligent extraction strategy: + 1. Finds all ```python...``` blocks + 2. Scores each block (prefers longer, non-test code after positive cues) + 3. Returns highest-scoring block + 4. Falls back to generic blocks or raw text + + Returns None if no Python code found. 
+ """ + # Try explicit python code blocks first + python_block_pattern = r"```python\s*\n(.*?)```" + matches = re.findall(python_block_pattern, text, re.DOTALL) + + if matches: + if len(matches) == 1: + return matches[0].strip() + + # Multiple blocks - need to be smart about which one + best_block = None + best_score = -999 + + # Find positions of each match to get context + for match in matches: + # Get text before this code block for context + match_pos = text.find(f"```python\n{match}") + context_before = text[max(0, match_pos - 200) : match_pos] + + score = _score_code_block(match, context_before) + + if score > best_score: + best_score = score + best_block = match + + return best_block.strip() if best_block else matches[0].strip() + + # Try generic code blocks + generic_block_pattern = r"```\s*\n(.*?)```" + matches = re.findall(generic_block_pattern, text, re.DOTALL) + if matches: + # Check if any look like Python + for match in matches: + candidate = match.strip() + if any( + keyword in candidate + for keyword in [ + "def ", + "class ", + "import ", + "from ", + "if ", + "for ", + "while ", + ] + ): + return candidate + + # If no code blocks, check if entire text looks like Python + stripped_text = text.strip() + if any( + keyword in stripped_text for keyword in ["def ", "class ", "import ", "from "] + ): + return stripped_text + + return None + + +def _has_python_code_listing(ctx: Context) -> ValidationResult: + """Validate that context contains extractable Python code.""" + last_output = ctx.last_output() + if last_output is None or last_output.value is None: + return ValidationResult(result=False, reason="No output found in context") + + code = extract_python_code(last_output.value) + if code is None: + return ValidationResult( + result=False, reason="No Python code block found in output" + ) + + return ValidationResult( + result=True, + reason=code, # Return extracted code for downstream use + ) + + +class HasPythonCodeListing(Requirement): + """Verifies 
def _python_code_parses(ctx: Context) -> ValidationResult:
    """Check that the code extracted from the context is valid Python syntax.

    Extraction failures are passed through with their own reason (or a generic
    one if none was given); otherwise the code is parsed with ``ast`` and the
    outcome reported.
    """
    extracted = _has_python_code_listing(ctx)
    if not extracted.as_bool():
        return ValidationResult(
            result=False,
            reason=extracted.reason or "Could not extract Python code",
        )

    # On success the extractor carries the code in the ``reason`` field.
    source = extracted.reason
    assert source is not None

    try:
        ast.parse(source)
    except SyntaxError as err:
        verdict, why = False, f"Syntax error at line {err.lineno}: {err.msg}"
    except Exception as err:
        verdict, why = False, f"Parse error: {err!s}"
    else:
        verdict, why = True, "Python code parses successfully"

    return ValidationResult(result=verdict, reason=why)
modules.append(node.module.split(".")[0]) + + return list(set(modules)) # Remove duplicates + + +def is_module_available(module_name: str, venv_path: str | None = None) -> bool: + """Check if a module is available in the system or specified venv.""" + if venv_path: + # Check in specified venv using pip list + try: + result = subprocess.run( + [f"{venv_path}/bin/python", "-m", "pip", "list", "--format=freeze"], + capture_output=True, + text=True, + timeout=5, + ) + installed = [ + line.split("==")[0].lower() for line in result.stdout.split("\n") + ] + return module_name.lower() in installed + except Exception: + return False + else: + # Check in current environment + return importlib.util.find_spec(module_name) is not None + + +def _python_valid_imports( + ctx: Context, venv_path: str | None = None +) -> ValidationResult: + """Validate that all imports in Python code are available.""" + # First extract and parse the code + extraction_result = _has_python_code_listing(ctx) + if not extraction_result.as_bool(): + return ValidationResult( + result=False, reason="Could not extract Python code for import validation" + ) + + code = extraction_result.reason + assert code is not None + + # Check if code parses + try: + ast.parse(code) + except SyntaxError: + return ValidationResult( + result=False, reason="Code has syntax errors, cannot validate imports" + ) + + modules = get_imported_modules(code) + if not modules: + # No imports is valid + return ValidationResult(result=True, reason="No imports to validate") + + unavailable_modules = [] + for module in modules: + if not is_module_available(module, venv_path): + unavailable_modules.append(module) + + if unavailable_modules: + return ValidationResult( + result=False, + reason=f"Unavailable modules: {', '.join(unavailable_modules)}", + ) + + return ValidationResult( + result=True, reason=f"All imports valid: {', '.join(modules)}" + ) + + +class PythonValidImports(Requirement): + """Verifies that all import statements 
def _python_executes_without_error(
    ctx: Context,
    timeout: int = 5,
    allow_unsafe: bool = False,
    allowed_imports: list[str] | None = None,
) -> ValidationResult:
    """Validate that Python code executes without raising exceptions.

    Args:
        ctx: Context whose last output contains the code.
        timeout: Seconds passed to the execution backend.
        allow_unsafe: If True, run the code via ``UnsafeBackend``; otherwise
            only validate it with ``SafeBackend``.
        allowed_imports: Import allow-list forwarded to ``UnsafeBackend``.

    Returns:
        A ValidationResult mirroring the backend's success flag, with the
        backend's message (or error) as the reason.
    """
    extraction_result = _has_python_code_listing(ctx)
    if not extraction_result.as_bool():
        return ValidationResult(
            result=False, reason="Could not extract Python code for execution"
        )

    # On success the extractor carries the code in the ``reason`` field.
    code = extraction_result.reason
    assert code is not None

    backend: ExecutionBackend
    if allow_unsafe:
        backend = UnsafeBackend(allowed_imports=allowed_imports)
    else:
        backend = SafeBackend()

    try:
        result = backend.execute(code, timeout)
    except SyntaxError as e:
        # The unsafe backend's import allow-list check parses the code without
        # guarding against invalid syntax. Report that as a failed validation,
        # like the other validators do, instead of letting it escape.
        return ValidationResult(
            result=False, reason=f"Syntax error at line {e.lineno}: {e.msg}"
        )

    return ValidationResult(
        result=result.success,
        reason=result.message or result.error,
    )
def _python_has_function_def(ctx: Context) -> ValidationResult:
    """Validate that Python code contains at least one function definition.

    Both ``def`` and ``async def`` definitions count, at any nesting depth
    (module level, methods, nested functions).

    Returns:
        A passing result listing the function names found, or a failing result
        if extraction fails, the code has syntax errors, or no function exists.
    """
    extraction_result = _has_python_code_listing(ctx)
    if not extraction_result.as_bool():
        return ValidationResult(result=False, reason="Could not extract Python code")

    # On success the extractor carries the code in the ``reason`` field.
    code = extraction_result.reason
    assert code is not None

    try:
        tree = ast.parse(code)
    except SyntaxError:
        return ValidationResult(result=False, reason="Code has syntax errors")

    # ``async def`` is a separate AST node type and must be matched explicitly.
    function_names = [
        node.name
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]

    if not function_names:
        return ValidationResult(
            result=False, reason="No function definitions found in code"
        )

    return ValidationResult(
        result=True,
        reason=f"Found {len(function_names)} function(s): {', '.join(function_names)}",
    )
def _python_matches_examples(
    ctx: Context, function_name: str, examples: list[tuple[dict, Any]]
) -> ValidationResult:
    """Validate that a Python function produces correct outputs for given examples.

    Args:
        ctx: Context containing the code.
        function_name: Name of the function to test.
        examples: List of (input_kwargs, expected_output) tuples. The inputs
            are copied before each call, so the generated code cannot modify
            the stored examples.

    Returns:
        ValidationResult indicating if all examples passed; on failure the
        reason lists each failing example.
    """
    import copy  # local import: only this validator needs it

    extraction_result = _has_python_code_listing(ctx)
    if not extraction_result.as_bool():
        return ValidationResult(result=False, reason="Could not extract Python code")

    # On success the extractor carries the code in the ``reason`` field.
    code = extraction_result.reason
    assert code is not None

    # Check if code parses
    try:
        ast.parse(code)
    except SyntaxError as e:
        return ValidationResult(result=False, reason=f"Code has syntax errors: {e}")

    # SECURITY NOTE(review): exec() runs model-generated code inside this
    # process, with no sandbox and no timeout. A fresh globals dict isolates
    # names only; it is not a security boundary. Only use with trusted
    # models/prompts, or move this to a subprocess-based backend.
    namespace: dict[str, Any] = {}
    try:
        exec(code, namespace)
    except Exception as e:
        return ValidationResult(result=False, reason=f"Code execution failed: {e!s}")

    # Check if function exists
    if function_name not in namespace:
        return ValidationResult(
            result=False, reason=f"Function '{function_name}' not found in code"
        )

    func = namespace[function_name]
    if not callable(func):
        return ValidationResult(
            result=False, reason=f"'{function_name}' is not callable"
        )

    # Test all examples
    failed_examples = []
    for number, (inputs, expected) in enumerate(examples, start=1):
        # The same examples are reused for every validation attempt. Hand the
        # function a private copy so in-place mutation (e.g. sorting a list
        # argument) cannot corrupt later attempts.
        try:
            call_kwargs = copy.deepcopy(inputs)
        except Exception:
            call_kwargs = inputs  # best effort for inputs that cannot be copied

        try:
            result = func(**call_kwargs)
            if result != expected:
                failed_examples.append(
                    f"Example {number}: {function_name}({inputs}) = {result}, expected {expected}"
                )
        except Exception as e:
            failed_examples.append(
                f"Example {number}: {function_name}({inputs}) raised {type(e).__name__}: {e!s}"
            )

    if failed_examples:
        return ValidationResult(
            result=False,
            reason=f"Failed {len(failed_examples)}/{len(examples)} examples:\n"
            + "\n".join(failed_examples),
        )

    return ValidationResult(
        result=True, reason=f"All {len(examples)} examples passed", score=1.0
    )
def _python_matches_docstring(
    ctx: Context,
    docstring: str,
    function_name: str | None,
    tests: list[tuple[dict, Any]],
) -> ValidationResult:
    """Run pre-generated tests against the code found in the context.

    Args:
        ctx: Context containing the code.
        docstring: The specification the tests were derived from. It is kept
            for reference only and is not read here.
        function_name: Function to test; when None, the first ``def`` found
            while walking the parsed code is used.
        tests: (input_kwargs, expected_output) tuples to check.

    Returns:
        The result of the example-based check, or a failing result if the code
        cannot be extracted, cannot be parsed, or defines no function.
    """
    extracted = _has_python_code_listing(ctx)
    if not extracted.as_bool():
        return ValidationResult(
            result=False, reason="Could not extract Python code from output"
        )

    # On success the extractor carries the code in the ``reason`` field.
    source = extracted.reason
    assert source is not None

    target = function_name
    if target is None:
        # No name supplied: infer it from the generated code.
        try:
            module = ast.parse(source)
        except SyntaxError:
            return ValidationResult(
                result=False, reason="Code has syntax errors, cannot identify function"
            )

        first_def = next(
            (node.name for node in ast.walk(module) if isinstance(node, ast.FunctionDef)),
            None,
        )
        if first_def is None:
            return ValidationResult(
                result=False, reason="No function found in generated code"
            )
        target = first_def

    # Delegate the actual checking to the example-based validator.
    return _python_matches_examples(ctx, target, tests)
+ + Example usage: + ```python + with start_session("ollama", "granite3.3:8b") as session: + verifier = PythonMatchesDocstring( + "Calculate factorial of n", + num_tests=5 + ) + # Generate tests once (uses active session's LLM) + verifier.generate_tests() + + # Now use in rejection sampling (no additional LLM calls) + result = session.instruct( + "Write a factorial function", + requirements=[verifier] + ) + ``` + """ + + def __init__( + self, + docstring: str, + function_name: str | None = None, + num_tests: int = 5, + tests: list[tuple[dict, Any]] | None = None, + ): + """Initialize docstring-based correctness validator. + + Args: + docstring: The specification/docstring describing expected behavior. + function_name: Name of the function to test. If None, will infer from code. + num_tests: Number of test cases to generate (default: 5). + tests: Pre-generated tests. If provided, skips LLM generation. + Format: list of (input_dict, expected_output) tuples. + """ + self._docstring = docstring + self._function_name = function_name + self._num_tests = num_tests + self._cached_tests: list[tuple[dict, Any]] | None = tests + + super().__init__( + description=f"The code should match the specification: {docstring[:100]}{'...' if len(docstring) > 100 else ''}", + validation_fn=lambda ctx: self._validate(ctx), + check_only=True, + ) + + def generate_tests(self, temperature: float = 0.3): + """Generate test cases from the docstring using the active LLM session. + + This should be called once before using the verifier. Requires an active session. + + Args: + temperature: LLM temperature for test generation (default: 0.3 for consistency). + + Raises: + RuntimeError: If no active session is found. 
+ """ + if self._cached_tests is not None: + return # Already have tests + + from mellea.stdlib.session import get_session + + session = get_session() + + # Infer function name if needed (use placeholder for generation) + func_name = self._function_name or "function" + + test_prompt = f"""Given this function specification, generate {self._num_tests} diverse test cases. + +Specification: +{self._docstring} + +Function name: {func_name} + +Generate test cases as a JSON array. Each test has "inputs" (dict of param names to values) and "expected_output". + +Example format: +[ + {{"inputs": {{"n": 5}}, "expected_output": 120}}, + {{"inputs": {{"n": 0}}, "expected_output": 1}} +] + +Focus on: +1. Normal cases (typical inputs) +2. Edge cases (boundaries, empty values) +3. Different data types if applicable + +Output ONLY the JSON array, no other text.""" + + # Generate using session.instruct (handles async properly) + result = session.instruct( + test_prompt, model_options={"temperature": temperature} + ) + + # Parse JSON response + import json + + test_json_str = result.value + assert test_json_str is not None + + # Extract JSON from markdown blocks if present + if "```json" in test_json_str: + test_json_str = test_json_str.split("```json")[1].split("```")[0].strip() + elif "```" in test_json_str: + test_json_str = test_json_str.split("```")[1].split("```")[0].strip() + + # Fix common Python->JSON inconsistencies + test_json_str = test_json_str.replace(": None", ": null") + test_json_str = test_json_str.replace(": True", ": true") + test_json_str = test_json_str.replace(": False", ": false") + + try: + test_data = json.loads(test_json_str) + except json.JSONDecodeError as e: + raise ValueError( + f"Failed to parse test JSON. 
LLM output:\n{test_json_str[:500]}\nError: {e}" + ) + + # Convert to expected format with flexible key names + self._cached_tests = [] + for i, test in enumerate(test_data): + # Skip tests that expect errors (we only test success cases) + if "expected_error" in test or "error" in test: + continue + + # Handle various possible key names + inputs = test.get("inputs") or test.get("input") or test.get("args") or {} + expected = ( + test.get("expected_output") + or test.get("output") + or test.get("expected") + or test.get("result") + ) + + if expected is None: + # Skip if no expected output (might be an error case) + continue + + # Skip if expected output looks like an error message + if isinstance(expected, str) and expected.lower().startswith("error"): + continue + + self._cached_tests.append((inputs, expected)) + + if not self._cached_tests: + raise ValueError( + f"No valid test cases found. LLM response:\n{test_json_str[:500]}" + ) + + def _validate(self, ctx: Context) -> ValidationResult: + """Validation function.""" + if self._cached_tests is None: + return ValidationResult( + result=False, + reason="Tests not generated. 
def from_model(s: str) -> Context:
    """Wrap ``s`` as a model output and add it to a new ChatContext."""
    chat = ChatContext()
    model_output = ModelOutputThunk(value=s, meta={"test": True})
    return chat.add(model_output)
+``` +""" +) + +VALID_PYTHON_PLAIN_CTX = from_model( + """ +def add(a, b): + return a + b +""" +) + +INVALID_SYNTAX_CTX = from_model( + """ +```python +def broken_function( + print("missing closing paren") +``` +""" +) + +PYTHON_WITH_IMPORTS_CTX = from_model( + """ +```python +import os +import sys +from pathlib import Path + +def get_home(): + return Path.home() +``` +""" +) + +PYTHON_WITH_INVALID_IMPORTS_CTX = from_model( + """ +```python +import nonexistent_package_xyz +import another_fake_module + +def foo(): + pass +``` +""" +) + +PYTHON_EXECUTABLE_CTX = from_model( + """ +```python +def multiply(x, y): + return x * y + +if __name__ == "__main__": + result = multiply(3, 4) + print(f"Result: {result}") +``` +""" +) + +PYTHON_INFINITE_LOOP_CTX = from_model( + """ +```python +while True: + pass +``` +""" +) + +PYTHON_WITH_CLASS_CTX = from_model( + """ +```python +class Calculator: + def add(self, a, b): + return a + b + + def subtract(self, a, b): + return a - b +``` +""" +) + +NO_CODE_CTX = from_model( + """ +This is just a text response with no code at all. +""" +) + +# endregion + +# region: Edge case contexts + +MULTIPLE_CODE_BLOCKS_CTX = from_model( + """ +Here's the correct version: +```python +def good(): + return "correct" +``` + +And here's a test for it: +```python +def test_good(): + assert good() == "correct" +``` +""" +) + +# Edge case: Bad code shown first, then good code +BAD_THEN_GOOD_CTX = from_model( + """ +Here's a wrong approach (don't use this): +```python +def add(a, b): + return str(a) + str(b) # Wrong! 
Concatenates +``` + +Here's the correct way: +```python +def add(a, b): + return a + b +``` +""" +) + +# Edge case: Evolution (simple then improved) +EVOLUTION_CTX = from_model( + """ +Let's start simple: +```python +def add(a, b): + return a + b +``` + +Now with type hints and docstring: +```python +def add(a: int, b: int) -> int: + '''Add two numbers.''' + return a + b +``` +""" +) + +MIXED_LANGUAGE_CTX = from_model( + """ +```javascript +function hello() { + console.log("Hello"); +} +``` + +```python +def greet(): + print("Hello") +``` +""" +) + +COMPLEX_CODE_CTX = from_model( + ''' +```python +"""Module docstring.""" + +def fibonacci(n: int) -> list[int]: + """ + Calculate fibonacci sequence. + + Args: + n: Number of terms + + Returns: + List of fibonacci numbers + """ + # Initialize first two terms + if n <= 0: + return [] + elif n == 1: + return [0] + + fib = [0, 1] + for i in range(2, n): + fib.append(fib[i-1] + fib[i-2]) + return fib +``` +''' +) + +MULTILINE_STRING_CTX = from_model( + ''' +```python +def format_message(name, age): + message = f""" + Hello {name}! + You are {age} years old. 
+ """ + return message.strip() +``` +''' +) + +NESTED_FUNCTIONS_CTX = from_model( + """ +```python +def outer(x): + def inner(y): + return x + y + return inner + +add_five = outer(5) +result = add_five(3) +``` +""" +) + +DECORATORS_CTX = from_model( + """ +```python +def my_decorator(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + +@my_decorator +def greet(name): + return f"Hello {name}" +``` +""" +) + +COMPREHENSIONS_CTX = from_model( + """ +```python +def process_data(): + squares = [x**2 for x in range(10)] + evens = {x: x**2 for x in range(10) if x % 2 == 0} + return squares, evens +``` +""" +) + +EXCEPTION_HANDLING_CTX = from_model( + """ +```python +def safe_divide(a, b): + try: + return a / b + except ZeroDivisionError: + return None + except TypeError as e: + raise ValueError(f"Invalid types: {e}") + finally: + print("Division attempted") +``` +""" +) + +CONTEXT_MANAGER_CTX = from_model( + """ +```python +class MyContext: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + +def use_context(): + with MyContext() as ctx: + pass +``` +""" +) + +ASYNC_CODE_CTX = from_model( + """ +```python +import asyncio + +async def fetch_data(): + await asyncio.sleep(1) + return "data" + +async def main(): + result = await fetch_data() + return result +``` +""" +) + +GENERATOR_CTX = from_model( + """ +```python +def count_up_to(n): + i = 0 + while i < n: + yield i + i += 1 +``` +""" +) + +LAMBDA_CTX = from_model( + """ +```python +add = lambda x, y: x + y +squares = list(map(lambda x: x**2, range(5))) +``` +""" +) + +ADVANCED_TYPE_HINTS_CTX = from_model( + """ +```python +from typing import Union, Optional, List, Dict + +def process(data: Union[str, int], config: Optional[Dict[str, any]] = None) -> List[str]: + return [str(data)] +``` +""" +) + +CODE_IN_STRING_CTX = from_model( + """ +```python +def example(): + code_snippet = "def bad( syntax error" + return code_snippet +``` +""" +) + 
# Model output whose only definition is a function with a `pass` body.
EMPTY_FUNCTION_CTX = from_model(
    """
```python
def placeholder():
    pass
```
"""
)

# Model output made of imports and one module-level assignment; no functions.
ONLY_IMPORTS_CTX = from_model(
    """
```python
import os
import sys
from pathlib import Path

BASE_DIR = Path(__file__).parent
```
"""
)

# Model output with a class using a property, a staticmethod and a classmethod.
COMPLEX_CLASS_CTX = from_model(
    """
```python
class Calculator:
    def __init__(self):
        self._value = 0

    @property
    def value(self):
        return self._value

    @staticmethod
    def add(a, b):
        return a + b

    @classmethod
    def create(cls):
        return cls()
```
"""
)

# Model output that is syntactically valid but references an undefined name
# when the function is called at module level.
RUNTIME_ERROR_CTX = from_model(
    """
```python
def will_fail():
    undefined_variable = nonexistent_var + 1
    return undefined_variable

# Actually call it to trigger the error
result = will_fail()
```
"""
)

# Model output that uses package-relative imports.
RELATIVE_IMPORT_CTX = from_model(
    """
```python
from . import sibling_module
from ..parent import something

def use_imports():
    pass
```
"""
)

# endregion

# region: Basic extraction tests


def test_extract_python_from_markdown():
    """A fenced python block yields the function it contains."""
    response_text = VALID_PYTHON_MARKDOWN_CTX.last_output().value
    extracted = extract_python_code(response_text)
    assert extracted is not None
    assert "def hello_world():" in extracted


def test_extract_python_from_generic_block():
    """The generic-block fixture yields the `greet` function."""
    response_text = VALID_PYTHON_GENERIC_BLOCK_CTX.last_output().value
    extracted = extract_python_code(response_text)
    assert extracted is not None
    assert "def greet(name):" in extracted


def test_extract_python_from_plain_text():
    """The plain-text fixture yields the `add` function."""
    response_text = VALID_PYTHON_PLAIN_CTX.last_output().value
    extracted = extract_python_code(response_text)
    assert extracted is not None
    assert "def add(a, b):" in extracted


def test_extract_python_no_code():
    """Extraction returns None for the no-code fixture."""
    response_text = NO_CODE_CTX.last_output().value
    assert extract_python_code(response_text) is None


def test_extract_multiple_code_blocks():
    """With several blocks present, the main function wins over the test."""
    response_text = MULTIPLE_CODE_BLOCKS_CTX.last_output().value
    extracted = extract_python_code(response_text)
    assert extracted is not None
    assert "def good" in extracted
    # The accompanying test block must not be the one selected.
    assert "test_good" not in extracted


def test_extract_bad_then_good():
    """The good implementation is chosen even though the bad one is first."""
    response_text = BAD_THEN_GOOD_CTX.last_output().value
    extracted = extract_python_code(response_text)
    assert extracted is not None
    assert "return a + b" in extracted
    # The earlier, incorrect version must be skipped.
    assert "str(a)" not in extracted


def test_extract_evolution():
    """When the code evolves, the more complete version is extracted."""
    response_text = EVOLUTION_CTX.last_output().value
    extracted = extract_python_code(response_text)
    assert extracted is not None
    # Either marker of the longer, type-hinted version is acceptable.
    acceptable_markers = ("int", "Add two numbers")
    assert any(marker in extracted for marker in acceptable_markers)


def test_extract_from_mixed_languages():
    """Python is extracted and the JavaScript snippet is left out."""
    response_text = MIXED_LANGUAGE_CTX.last_output().value
    extracted = extract_python_code(response_text)
    assert extracted is not None
    assert "def greet" in extracted
    assert "function hello" not in extracted


# endregion

# region: HasPythonCodeListing tests


def test_has_python_code_listing_valid():
    """A markdown python block passes and the reason carries the code."""
    verdict = HasPythonCodeListing().validation_fn(VALID_PYTHON_MARKDOWN_CTX)
    assert verdict.as_bool() is True
    assert "def hello_world():" in verdict.reason


def test_has_python_code_listing_invalid():
    """The no-code fixture fails the listing requirement."""
    verdict = HasPythonCodeListing().validation_fn(NO_CODE_CTX)
    assert verdict.as_bool() is False


# endregion

# region: PythonCodeParses tests


def test_python_code_parses_valid():
    """The valid markdown fixture parses."""
    verdict = PythonCodeParses().validation_fn(VALID_PYTHON_MARKDOWN_CTX)
    assert verdict.as_bool() is True


def test_python_code_parses_invalid():
    """The invalid-syntax fixture fails with a syntax/parse error reason."""
    verdict = PythonCodeParses().validation_fn(INVALID_SYNTAX_CTX)
    assert verdict.as_bool() is False
    expected_prefixes = ("Syntax error", "Parse error")
    assert any(prefix in verdict.reason for prefix in expected_prefixes)


def test_python_code_parses_no_code():
    """The no-code fixture fails the parse requirement."""
    verdict = PythonCodeParses().validation_fn(NO_CODE_CTX)
    assert verdict.as_bool() is False


def test_complex_code_with_docstrings():
    """The COMPLEX_CODE_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(COMPLEX_CODE_CTX).as_bool() is True


def test_multiline_strings():
    """The MULTILINE_STRING_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(MULTILINE_STRING_CTX).as_bool() is True


def test_nested_functions():
    """The NESTED_FUNCTIONS_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(NESTED_FUNCTIONS_CTX).as_bool() is True


def test_decorators():
    """The DECORATORS_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(DECORATORS_CTX).as_bool() is True


def test_comprehensions():
    """The COMPREHENSIONS_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(COMPREHENSIONS_CTX).as_bool() is True


def test_exception_handling():
    """The EXCEPTION_HANDLING_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(EXCEPTION_HANDLING_CTX).as_bool() is True


def test_context_managers():
    """The CONTEXT_MANAGER_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(CONTEXT_MANAGER_CTX).as_bool() is True


def test_async_code():
    """The ASYNC_CODE_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(ASYNC_CODE_CTX).as_bool() is True


def test_generators():
    """The GENERATOR_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(GENERATOR_CTX).as_bool() is True


def test_lambda_functions():
    """The LAMBDA_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(LAMBDA_CTX).as_bool() is True


def test_advanced_type_hints():
    """The ADVANCED_TYPE_HINTS_CTX fixture parses."""
    assert PythonCodeParses().validation_fn(ADVANCED_TYPE_HINTS_CTX).as_bool() is True


def test_code_in_string():
    """Syntax errors that live inside string literals do not block parsing."""
    assert PythonCodeParses().validation_fn(CODE_IN_STRING_CTX).as_bool() is True


def test_relative_imports_parse():
    """Package-relative imports are accepted by the parser check."""
    assert PythonCodeParses().validation_fn(RELATIVE_IMPORT_CTX).as_bool() is True


def test_runtime_error_code_parses():
    """Code that would fail at runtime still passes the parse check."""
    assert PythonCodeParses().validation_fn(RUNTIME_ERROR_CTX).as_bool() is True


# endregion

# region: PythonValidImports tests


def test_python_valid_imports_stdlib():
    """The PYTHON_WITH_IMPORTS_CTX fixture passes the import check."""
    verdict = PythonValidImports().validation_fn(PYTHON_WITH_IMPORTS_CTX)
    assert verdict.as_bool() is True


def test_python_valid_imports_invalid():
    """An unknown package fails the check and is named in the reason."""
    verdict = PythonValidImports().validation_fn(PYTHON_WITH_INVALID_IMPORTS_CTX)
    assert verdict.as_bool() is False
    assert "nonexistent_package_xyz" in verdict.reason


# endregion

# region: PythonExecutesWithoutError tests


def test_python_executes_without_error_valid():
    """Executable code passes when unsafe execution is enabled."""
    checker = PythonExecutesWithoutError(timeout=2, allow_unsafe_execution=True)
    assert checker.validation_fn(PYTHON_EXECUTABLE_CTX).as_bool() is True


def test_python_executes_without_error_timeout():
    """An infinite loop fails and the reason mentions the timeout."""
    checker = PythonExecutesWithoutError(timeout=1, allow_unsafe_execution=True)
    verdict = checker.validation_fn(PYTHON_INFINITE_LOOP_CTX)
    assert verdict.as_bool() is False
    assert "timed out" in verdict.reason.lower()


def test_python_executes_without_error_syntax():
    """Invalid syntax fails with the default constructor arguments."""
    verdict = PythonExecutesWithoutError().validation_fn(INVALID_SYNTAX_CTX)
    assert verdict.as_bool() is False


def test_runtime_error_code_fails_execution():
    """Code that raises at runtime fails the execution requirement."""
    checker = PythonExecutesWithoutError(timeout=2, allow_unsafe_execution=True)
    assert checker.validation_fn(RUNTIME_ERROR_CTX).as_bool() is False


def test_safe_mode_default():
    """Default construction uses safe mode, which validates without running."""
    verdict = PythonExecutesWithoutError().validation_fn(PYTHON_EXECUTABLE_CTX)
    assert verdict.as_bool() is True
    assert "safe mode" in verdict.reason


def test_safe_mode_syntax_error():
    """Safe mode still rejects code with syntax errors."""
    verdict = PythonExecutesWithoutError().validation_fn(INVALID_SYNTAX_CTX)
    assert verdict.as_bool() is False


def test_unsafe_execution_with_flag():
    """Explicitly enabling unsafe execution lets valid code pass."""
    checker = PythonExecutesWithoutError(allow_unsafe_execution=True, timeout=2)
    assert checker.validation_fn(PYTHON_EXECUTABLE_CTX).as_bool() is True


def test_unsafe_execution_with_import_restrictions():
    """Imports outside the allow-list are rejected before execution."""
    permitted = ["math", "json"]
    checker = PythonExecutesWithoutError(
        allow_unsafe_execution=True,
        allowed_imports=permitted,
    )
    verdict = checker.validation_fn(PYTHON_WITH_INVALID_IMPORTS_CTX)
    assert verdict.as_bool() is False
    assert "Unauthorized imports" in verdict.reason


def test_unsafe_execution_with_allowed_imports():
    """Code whose imports are all on the allow-list passes."""
    permitted = ["os", "sys", "pathlib"]
    checker = PythonExecutesWithoutError(
        allow_unsafe_execution=True,
        allowed_imports=permitted,
    )
    assert checker.validation_fn(PYTHON_WITH_IMPORTS_CTX).as_bool() is True


# endregion

# region: PythonHasFunctionDef tests


def test_python_has_function_def_valid():
    """A function definition is detected and its name appears in the reason."""
    verdict = PythonHasFunctionDef().validation_fn(VALID_PYTHON_MARKDOWN_CTX)
    assert verdict.as_bool() is True
    assert "hello_world" in verdict.reason


def test_python_has_function_def_invalid():
    """A bare assignment does not satisfy the function requirement."""
    assignment_only_ctx = from_model("```python\nx = 5\n```")
    verdict = PythonHasFunctionDef().validation_fn(assignment_only_ctx)
    assert verdict.as_bool() is False


def test_decorators_have_functions():
    """The DECORATORS_CTX fixture satisfies the function requirement."""
    assert PythonHasFunctionDef().validation_fn(DECORATORS_CTX).as_bool() is True


def test_empty_function():
    """A `pass`-only function both parses and counts as a function."""
    parse_verdict = PythonCodeParses().validation_fn(EMPTY_FUNCTION_CTX)
    assert parse_verdict.as_bool() is True

    function_verdict = PythonHasFunctionDef().validation_fn(EMPTY_FUNCTION_CTX)
    assert function_verdict.as_bool() is True


def test_only_imports_no_functions():
    """Import-only code parses but fails the function requirement."""
    parse_verdict = PythonCodeParses().validation_fn(ONLY_IMPORTS_CTX)
    assert parse_verdict.as_bool() is True

    function_verdict = PythonHasFunctionDef().validation_fn(ONLY_IMPORTS_CTX)
    assert function_verdict.as_bool() is False


# endregion

# region: PythonHasClassDef tests


def test_python_has_class_def_valid():
    """A class definition is detected and its name appears in the reason."""
    verdict = PythonHasClassDef().validation_fn(PYTHON_WITH_CLASS_CTX)
    assert verdict.as_bool() is True
    assert "Calculator" in verdict.reason


def test_python_has_class_def_invalid():
    """The valid markdown fixture fails the class requirement."""
    verdict = PythonHasClassDef().validation_fn(VALID_PYTHON_MARKDOWN_CTX)
    assert verdict.as_bool() is False


def test_context_managers_have_classes():
    """The CONTEXT_MANAGER_CTX fixture satisfies the class requirement."""
    assert PythonHasClassDef().validation_fn(CONTEXT_MANAGER_CTX).as_bool() is True


def test_complex_class():
    """The decorated-method class both parses and counts as a class."""
    parse_verdict = PythonCodeParses().validation_fn(COMPLEX_CLASS_CTX)
    assert parse_verdict.as_bool() is True

    class_verdict = PythonHasClassDef().validation_fn(COMPLEX_CLASS_CTX)
    assert class_verdict.as_bool() is True


# endregion

# region: Integration tests


def test_full_validation_pipeline():
    """Run several validators in sequence against one class-bearing context."""
    ctx = PYTHON_WITH_CLASS_CTX

    # Order: code is present, then it parses, then it contains a class.
    pipeline = (HasPythonCodeListing, PythonCodeParses, PythonHasClassDef)
    for requirement_cls in pipeline:
        assert requirement_cls().validation_fn(ctx).as_bool() is True


def test_chained_validation_complex_code():
    """Run the code/parse/function validators on the complex-code fixture."""
    ctx = COMPLEX_CODE_CTX

    pipeline = (HasPythonCodeListing, PythonCodeParses, PythonHasFunctionDef)
    for requirement_cls in pipeline:
        assert requirement_cls().validation_fn(ctx).as_bool() is True


# endregion

# region: PythonMatchesExamples tests

# Recursive factorial with a base case of 1 at n == 0.
FACTORIAL_CODE_CTX = from_model(
    """
```python
def factorial(n):
    if n == 0:
        return 1
    return n * factorial(n - 1)
```
"""
)

# Factorial variant whose base case returns 0 instead of 1.
INCORRECT_FACTORIAL_CODE_CTX = from_model(
    """
```python
def factorial(n):
    # Incorrect implementation - off by one
    if n == 0:
        return 0
    return n * factorial(n - 1)
```
"""
)

# Function returning the first n Fibonacci numbers as a list.
FIBONACCI_CODE_CTX = from_model(
    """
```python
def fibonacci(n):
    if n <= 0:
        return []
    elif n == 1:
        return [0]

    fib = [0, 1]
    for i in range(2, n):
        fib.append(fib[i-1] + fib[i-2])
    return fib
```
"""
)

# Two-argument addition function.
ADD_FUNCTION_CTX = from_model(
    """
```python
def add(a, b):
    return a + b
```
"""
)


def test_matches_examples_correct():
    """A correct factorial passes every example and reports the count."""
    cases = [
        ({"n": 0}, 1),
        ({"n": 1}, 1),
        ({"n": 5}, 120),
    ]
    checker = PythonMatchesExamples(function_name="factorial", examples=cases)
    verdict = checker.validation_fn(FACTORIAL_CODE_CTX)
    assert verdict.as_bool() is True
    assert "All 3 examples passed" in verdict.reason


def test_matches_examples_incorrect():
    """An incorrect factorial fails and the reason says so."""
    cases = [
        ({"n": 0}, 1),  # Will fail - function returns 0
        ({"n": 5}, 120),
    ]
    checker = PythonMatchesExamples(function_name="factorial", examples=cases)
    verdict = checker.validation_fn(INCORRECT_FACTORIAL_CODE_CTX)
    assert verdict.as_bool() is False
    assert "Failed" in verdict.reason


def test_matches_examples_function_not_found():
    """Asking for a function the code does not define fails with 'not found'."""
    cases = [({"n": 5}, 120)]
    checker = PythonMatchesExamples(function_name="nonexistent", examples=cases)
    verdict = checker.validation_fn(FACTORIAL_CODE_CTX)
    assert verdict.as_bool() is False
    assert "not found" in verdict.reason


def test_matches_examples_list_output():
    """Expected values may be lists."""
    cases = [
        ({"n": 0}, []),
        ({"n": 1}, [0]),
        ({"n": 5}, [0, 1, 1, 2, 3]),
    ]
    checker = PythonMatchesExamples(function_name="fibonacci", examples=cases)
    assert checker.validation_fn(FIBONACCI_CODE_CTX).as_bool() is True


def test_matches_examples_multiple_args():
    """Examples may supply more than one keyword argument."""
    cases = [
        ({"a": 2, "b": 3}, 5),
        ({"a": 0, "b": 0}, 0),
        ({"a": -1, "b": 1}, 0),
    ]
    checker = PythonMatchesExamples(function_name="add", examples=cases)
    assert checker.validation_fn(ADD_FUNCTION_CTX).as_bool() is True


def test_matches_examples_runtime_error():
    """An example that makes the function raise is reported as a failure."""
    cases = [
        ({"n": -1}, 1),  # Will cause infinite recursion
    ]
    checker = PythonMatchesExamples(function_name="factorial", examples=cases)
    # The exception is expected to be caught and turned into a failed result.
    verdict = checker.validation_fn(FACTORIAL_CODE_CTX)
    assert verdict.as_bool() is False


def test_matches_examples_no_code():
    """With no code in the context, the requirement fails."""
    cases = [({"x": 1}, 1)]
    checker = PythonMatchesExamples(function_name="foo", examples=cases)
    assert checker.validation_fn(NO_CODE_CTX).as_bool() is False


# endregion


if __name__ == "__main__":
    # Allow running this module directly; delegates to pytest in verbose mode.
    pytest.main([__file__, "-v"])