diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..a1f39c1 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,179 @@ +name: Test Suite + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + version: "latest" + + - name: Set up virtual environment + run: | + uv venv + echo "VIRTUAL_ENV=.venv" >> $GITHUB_ENV + echo "$PWD/.venv/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: | + uv sync --extra dev + + - name: Run linting + run: | + uv run ruff check . + uv run ruff format --check . + + - name: Run unit tests + run: | + uv run pytest tests/unit/ -v --cov=grainchain --cov-report=xml --cov-report=term-missing + + - name: Run integration tests (local only) + run: | + uv run pytest tests/integration/test_local_provider.py -v -m "not slow" + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: false + + test-with-providers: + runs-on: ubuntu-latest + if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/develop') + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v4 + with: + python-version: "3.12" + + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + version: "latest" + + - name: Set up virtual environment + run: | + uv venv + echo "VIRTUAL_ENV=.venv" >> $GITHUB_ENV + echo "$PWD/.venv/bin" >> $GITHUB_PATH + + - name: Install dependencies with all providers + run: | + uv sync --extra dev --extra all + + - name: Run integration tests with E2B + if: env.E2B_API_KEY != '' + env: + E2B_API_KEY: ${{ secrets.E2B_API_KEY }} + run: | + uv run pytest tests/integration/test_e2b_provider.py -v -m "not slow" + + - name: Run integration tests with Modal + if: env.MODAL_TOKEN_ID != '' && env.MODAL_TOKEN_SECRET != '' + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + run: | + uv run pytest tests/integration/test_modal_provider.py -v -m "not slow" + + - name: Run integration tests with Daytona + if: env.DAYTONA_API_KEY != '' + env: + DAYTONA_API_KEY: ${{ secrets.DAYTONA_API_KEY }} + run: | + uv run pytest tests/integration/test_daytona_provider.py -v -m "not slow" + + performance-test: + runs-on: ubuntu-latest + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v4 + with: + python-version: "3.12" + + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + version: "latest" + + - name: Set up virtual environment + run: | + uv venv + echo "VIRTUAL_ENV=.venv" >> $GITHUB_ENV + echo "$PWD/.venv/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: | + uv sync --extra dev --extra benchmark + + - name: Run performance tests + run: | + uv run pytest tests/ -v -m "slow" --timeout=300 + + - name: Run benchmark + run: | + uv run grainchain benchmark --provider local --output benchmarks/results/ + + security-scan: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - 
name: Set up Python 3.12 + uses: actions/setup-python@v4 + with: + python-version: "3.12" + + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + version: "latest" + + - name: Set up virtual environment + run: | + uv venv + echo "VIRTUAL_ENV=.venv" >> $GITHUB_ENV + echo "$PWD/.venv/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: | + uv sync --extra dev + + - name: Run security scan with bandit + run: | + uv run pip install bandit[toml] + uv run bandit -r grainchain/ -f json -o bandit-report.json || true + + - name: Upload security scan results + uses: actions/upload-artifact@v3 + with: + name: security-scan-results + path: bandit-report.json diff --git a/README.md b/README.md index 6935687..7f8be74 100644 --- a/README.md +++ b/README.md @@ -370,6 +370,95 @@ grainchain benchmark --provider local ./scripts/benchmark_status.sh ``` +### Testing + +Grainchain includes a comprehensive test suite with >90% code coverage to ensure reliability across all providers. + +### Running Tests + +```bash +# Run all tests +uv run pytest + +# Run only unit tests +uv run pytest tests/unit/ -v + +# Run only integration tests +uv run pytest tests/integration/ -v + +# Run tests with coverage +uv run pytest --cov=grainchain --cov-report=html + +# Run tests for specific provider +uv run pytest tests/integration/test_local_provider.py -v + +# Run performance tests +uv run pytest -m slow + +# Run tests excluding slow tests +uv run pytest -m "not slow" +``` + +### Test Categories + +- **Unit Tests** (`tests/unit/`): Fast, isolated tests for core functionality + - `test_sandbox.py`: Core Sandbox class tests + - `test_providers.py`: Provider implementation tests + - `test_config.py`: Configuration management tests + - `test_exceptions.py`: Exception handling tests + - `test_interfaces.py`: Interface and data structure tests + +- **Integration Tests** (`tests/integration/`): Real provider interactions + - `test_e2b_provider.py`: E2B provider integration tests + - `test_modal_provider.py`: Modal provider integration tests + - `test_daytona_provider.py`: Daytona provider integration tests + - `test_local_provider.py`: Local provider integration tests + +### Test Configuration + +Tests are configured via `pytest.ini` with the following markers: + +- `unit`: Unit tests (fast, no external dependencies) +- `integration`: Integration tests (require provider credentials) +- `slow`: Slow tests that may take longer to run +- `e2b`: Tests requiring E2B provider +- `modal`: Tests requiring Modal provider +- `daytona`: Tests requiring Daytona provider +- `local`: Tests requiring Local provider +- `snapshot`: Tests for snapshot functionality + +### Provider Credentials for Integration Tests + +To run integration tests with real providers, set these environment variables: + +```bash +# E2B +export E2B_API_KEY=your-e2b-api-key + +# Modal +export MODAL_TOKEN_ID=your-modal-token-id +export MODAL_TOKEN_SECRET=your-modal-token-secret + +# Daytona +export DAYTONA_API_KEY=your-daytona-api-key +``` + +### Continuous Integration + +Tests run automatically on GitHub Actions for: + +- **Python 3.12** on pull requests and main branch pushes +- **Integration tests** with real providers on main branch pushes +- **Performance tests** and benchmarks on releases +- **Security scans** with bandit on all commits + +### Coverage Requirements + +- Minimum coverage: **90%** +- All new code must include tests +- Integration tests must cover happy path and error scenarios +- Performance tests ensure operations complete within expected 
timeframes + ### CLI Commands Grainchain includes a comprehensive CLI for development: diff --git a/benchmarks/scripts/benchmark_runner.py b/benchmarks/scripts/benchmark_runner.py index 081ed86..f32a826 100755 --- a/benchmarks/scripts/benchmark_runner.py +++ b/benchmarks/scripts/benchmark_runner.py @@ -198,9 +198,9 @@ def take_snapshot(self, snapshot_name: str) -> dict[str, Any]: ) if result.exit_code == 0: size_output = result.output.decode().strip() - snapshot["metrics"]["filesystem"][ - "node_modules_size" - ] = size_output.split()[0] + snapshot["metrics"]["filesystem"]["node_modules_size"] = ( + size_output.split()[0] + ) # Package count result = self.container.exec_run( @@ -224,9 +224,9 @@ def take_snapshot(self, snapshot_name: str) -> dict[str, Any]: } if result.exit_code != 0: - snapshot["metrics"]["performance"][ - "build_error" - ] = result.output.decode() + snapshot["metrics"]["performance"]["build_error"] = ( + result.output.decode() + ) # Test run time (if tests exist) start_time = time.time() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..1b4d16b --- /dev/null +++ b/pytest.ini @@ -0,0 +1,28 @@ +[pytest] +asyncio_mode = auto +asyncio_default_fixture_loop_scope = function +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = + --cov=grainchain + --cov-report=term-missing + --cov-report=html:htmlcov + --cov-report=xml + --cov-fail-under=90 + --strict-markers + --strict-config + -v +markers = + unit: Unit tests + integration: Integration tests + slow: Slow tests that may take longer to run + e2b: Tests requiring E2B provider + modal: Tests requiring Modal provider + daytona: Tests requiring Daytona provider + local: Tests requiring Local provider + snapshot: Tests for snapshot functionality +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..5f20b9d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,427 @@ +"""Pytest configuration and fixtures for Grainchain tests.""" + +import asyncio +import os +import tempfile +from collections.abc import AsyncGenerator, Generator +from pathlib import Path + +import pytest + +from grainchain import Sandbox, SandboxConfig +from grainchain.core.config import ConfigManager, ProviderConfig +from grainchain.core.exceptions import GrainchainError +from grainchain.core.interfaces import ( + ExecutionResult, + FileInfo, + SandboxSession, + SandboxStatus, +) +from grainchain.providers.base import BaseSandboxProvider, BaseSandboxSession + + +# Test configuration +@pytest.fixture +def test_config() -> SandboxConfig: + """Create a test sandbox configuration.""" + return SandboxConfig( + timeout=30, + memory_limit="1GB", + cpu_limit=1.0, + working_directory="/tmp", + environment_vars={"TEST_ENV": "test_value"}, + auto_cleanup=True, + ) + + +@pytest.fixture +def provider_config() -> ProviderConfig: + """Create a test provider configuration.""" + return ProviderConfig( + name="test_provider", + config={ + "api_key": "test_key", + "timeout": 30, + "test_setting": "test_value", + }, + ) + + +@pytest.fixture +def config_manager() -> ConfigManager: + """Create a test configuration manager.""" + # Create a temporary config file + config_data = { + "default_provider": "test", + "providers": { + "test": { + "api_key": "test_key", + "timeout": 30, + }, + "e2b": { + "api_key": "test_e2b_key", + "template": "python", + }, + "modal": { + "token_id": "test_modal_id", + 
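+                # Placeholder credentials for the test configuration; these are
+                # obviously not real Modal tokens.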
"token_secret": "test_modal_secret", + }, + "daytona": { + "api_key": "test_daytona_key", + }, + "local": {}, + }, + "sandbox_defaults": { + "timeout": 180, + "working_directory": "/workspace", + "auto_cleanup": True, + }, + } + + manager = ConfigManager() + manager._config = config_data + manager._init_providers() + return manager + + +# Mock providers and sessions +class MockSandboxSession(BaseSandboxSession): + """Mock sandbox session for testing.""" + + def __init__( + self, sandbox_id: str, provider: "MockSandboxProvider", config: SandboxConfig + ): + super().__init__(sandbox_id, provider, config) + self._set_status(SandboxStatus.RUNNING) + self.executed_commands = [] + self.uploaded_files = {} + self.snapshots = {} + + async def execute( + self, + command: str, + timeout: int = None, + working_dir: str = None, + environment: dict[str, str] = None, + ) -> ExecutionResult: + """Mock command execution.""" + self._ensure_not_closed() + + # Record the command + self.executed_commands.append( + { + "command": command, + "timeout": timeout, + "working_dir": working_dir, + "environment": environment, + } + ) + + # Simulate different command responses + if command == "echo 'Hello, World!'": + return ExecutionResult( + command=command, + return_code=0, + stdout="Hello, World!\n", + stderr="", + execution_time=0.1, + success=True, + ) + elif command.startswith("python -c"): + return ExecutionResult( + command=command, + return_code=0, + stdout="Python output\n", + stderr="", + execution_time=0.2, + success=True, + ) + elif command == "exit 1": + return ExecutionResult( + command=command, + return_code=1, + stdout="", + stderr="Command failed\n", + execution_time=0.1, + success=False, + ) + elif "timeout" in command: + # Simulate timeout + raise asyncio.TimeoutError("Command timed out") + else: + return ExecutionResult( + command=command, + return_code=0, + stdout=f"Output for: {command}\n", + stderr="", + execution_time=0.1, + success=True, + ) + + async def upload_file( + self, path: str, content: bytes | str, mode: str = "w" + ) -> None: + """Mock file upload.""" + self._ensure_not_closed() + if isinstance(content, str): + content = content.encode() + self.uploaded_files[path] = {"content": content, "mode": mode} + + async def download_file(self, path: str) -> bytes: + """Mock file download.""" + self._ensure_not_closed() + if path in self.uploaded_files: + return self.uploaded_files[path]["content"] + elif path == "/test/existing_file.txt": + return b"Existing file content" + else: + raise FileNotFoundError(f"File not found: {path}") + + async def list_files(self, path: str = "/") -> list[FileInfo]: + """Mock file listing.""" + self._ensure_not_closed() + files = [] + for file_path in self.uploaded_files: + if file_path.startswith(path): + content = self.uploaded_files[file_path]["content"] + files.append( + FileInfo( + name=Path(file_path).name, + path=file_path, + size=len(content), + is_directory=False, + modified_time=None, + ) + ) + + # Add some default files + if path == "/" or path == "/test": + files.append( + FileInfo( + name="existing_file.txt", + path="/test/existing_file.txt", + size=20, + is_directory=False, + modified_time=None, + ) + ) + + return files + + async def create_snapshot(self) -> str: + """Mock snapshot creation.""" + self._ensure_not_closed() + snapshot_id = f"snapshot_{len(self.snapshots)}" + self.snapshots[snapshot_id] = { + "files": self.uploaded_files.copy(), + "commands": self.executed_commands.copy(), + } + return snapshot_id + + async def 
restore_snapshot(self, snapshot_id: str) -> None: + """Mock snapshot restoration.""" + self._ensure_not_closed() + if snapshot_id not in self.snapshots: + raise ValueError(f"Snapshot not found: {snapshot_id}") + + snapshot = self.snapshots[snapshot_id] + self.uploaded_files = snapshot["files"].copy() + # Note: We don't restore commands as they're historical + + async def _cleanup(self) -> None: + """Mock cleanup.""" + pass + + +class MockSandboxProvider(BaseSandboxProvider): + """Mock sandbox provider for testing.""" + + def __init__(self, config: ProviderConfig, name: str = "mock"): + super().__init__(config) + self._name = name + self.created_sessions = [] + self.cleanup_called = False + + @property + def name(self) -> str: + return self._name + + async def _create_session(self, config: SandboxConfig) -> SandboxSession: + """Create a mock sandbox session.""" + sandbox_id = f"mock_sandbox_{len(self.created_sessions)}" + session = MockSandboxSession(sandbox_id, self, config) + self.created_sessions.append(session) + return session + + async def cleanup(self) -> None: + """Mock cleanup method.""" + self.cleanup_called = True + await super().cleanup() + + +@pytest.fixture +def mock_provider(provider_config: ProviderConfig) -> MockSandboxProvider: + """Create a mock sandbox provider.""" + return MockSandboxProvider(provider_config) + + +@pytest.fixture +async def mock_session( + mock_provider: MockSandboxProvider, test_config: SandboxConfig +) -> AsyncGenerator[MockSandboxSession, None]: + """Create a mock sandbox session.""" + session = await mock_provider._create_session(test_config) + try: + yield session + finally: + await session.close() + + +@pytest.fixture +async def mock_sandbox( + mock_provider: MockSandboxProvider, test_config: SandboxConfig +) -> AsyncGenerator[Sandbox, None]: + """Create a mock sandbox instance.""" + sandbox = Sandbox(provider=mock_provider, config=test_config) + async with sandbox: + yield sandbox + + +# Environment fixtures +@pytest.fixture +def temp_dir() -> Generator[Path, None, None]: + """Create a temporary directory for tests.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +@pytest.fixture +def env_vars() -> Generator[dict[str, str], None, None]: + """Set up test environment variables.""" + original_env = os.environ.copy() + test_env = { + "GRAINCHAIN_DEFAULT_PROVIDER": "test", + "E2B_API_KEY": "test_e2b_key", + "MODAL_TOKEN_ID": "test_modal_id", + "MODAL_TOKEN_SECRET": "test_modal_secret", + "DAYTONA_API_KEY": "test_daytona_key", + } + + # Set test environment variables + for key, value in test_env.items(): + os.environ[key] = value + + try: + yield test_env + finally: + # Restore original environment + os.environ.clear() + os.environ.update(original_env) + + +# Error simulation fixtures +@pytest.fixture +def failing_provider(provider_config: ProviderConfig) -> MockSandboxProvider: + """Create a provider that fails operations.""" + + class FailingProvider(MockSandboxProvider): + async def _create_session(self, config: SandboxConfig) -> SandboxSession: + raise GrainchainError("Simulated provider failure") + + return FailingProvider(provider_config, "failing") + + +@pytest.fixture +def timeout_provider(provider_config: ProviderConfig) -> MockSandboxProvider: + """Create a provider that times out.""" + + class TimeoutProvider(MockSandboxProvider): + async def _create_session(self, config: SandboxConfig) -> SandboxSession: + await asyncio.sleep(10) # This will timeout in tests + return await super()._create_session(config) + + 
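+    # The unconditional 10-second sleep above is what timeout tests rely on: session
+    # creation stalls long enough to exceed whatever limit the calling test sets.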
return TimeoutProvider(provider_config, "timeout") + + +# Integration test fixtures (only if providers are available) +@pytest.fixture +def e2b_available() -> bool: + """Check if E2B provider is available.""" + try: + import e2b # noqa: F401 + + return bool(os.getenv("E2B_API_KEY")) + except ImportError: + return False + + +@pytest.fixture +def modal_available() -> bool: + """Check if Modal provider is available.""" + try: + import modal # noqa: F401 + + return bool(os.getenv("MODAL_TOKEN_ID") and os.getenv("MODAL_TOKEN_SECRET")) + except ImportError: + return False + + +@pytest.fixture +def daytona_available() -> bool: + """Check if Daytona provider is available.""" + try: + import daytona_sdk # noqa: F401 + + return bool(os.getenv("DAYTONA_API_KEY")) + except ImportError: + return False + + +# Async test utilities +@pytest.fixture +def event_loop(): + """Create an event loop for async tests.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +# Test data fixtures +@pytest.fixture +def sample_python_code() -> str: + """Sample Python code for testing.""" + return """ +import sys +import json + +def main(): + data = {"message": "Hello from Python!", "version": sys.version} + print(json.dumps(data)) + +if __name__ == "__main__": + main() +""" + + +@pytest.fixture +def sample_files() -> dict[str, str]: + """Sample files for testing.""" + return { + "hello.py": "print('Hello, World!')", + "data.json": '{"test": "data", "numbers": [1, 2, 3]}', + "script.sh": "#!/bin/bash\necho 'Shell script executed'", + "requirements.txt": "requests==2.31.0\npandas==2.0.3", + } + + +# Performance testing fixtures +@pytest.fixture +def performance_config() -> SandboxConfig: + """Configuration optimized for performance testing.""" + return SandboxConfig( + timeout=10, # Shorter timeout for performance tests + memory_limit="512MB", + cpu_limit=0.5, + working_directory="/tmp", + auto_cleanup=True, + ) diff --git a/tests/integration/test_daytona_provider.py b/tests/integration/test_daytona_provider.py new file mode 100644 index 0000000..13b699d --- /dev/null +++ b/tests/integration/test_daytona_provider.py @@ -0,0 +1,469 @@ +"""Integration tests for Daytona provider.""" + +import pytest + +from grainchain import Sandbox, SandboxConfig +from grainchain.core.config import ProviderConfig +from grainchain.core.exceptions import ConfigurationError, ProviderError, SandboxError + + +class TestDaytonaProviderIntegration: + """Integration tests for Daytona provider.""" + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_provider_real_connection(self, daytona_available): + """Test real Daytona provider connection (requires API key).""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="daytona", config=config) as sandbox: + assert sandbox.provider_name == "daytona" + assert sandbox.sandbox_id is not None + + # Test basic command execution + result = await sandbox.execute("echo 'Hello from Daytona!'") + assert result.return_code == 0 + assert "Hello from Daytona!" 
in result.stdout + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_development_environment(self, daytona_available): + """Test Daytona development environment features.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Test common development tools + tools_to_test = ["git", "curl", "wget", "vim"] + + for tool in tools_to_test: + result = await sandbox.execute(f"which {tool}") + assert result.return_code == 0, f"{tool} not available" + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_python_execution(self, daytona_available): + """Test Python code execution on Daytona.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Test Python execution + python_code = "import sys; print(f'Python {sys.version_info.major}.{sys.version_info.minor}')" + result = await sandbox.execute(f'python3 -c "{python_code}"') + + assert result.return_code == 0 + assert "Python" in result.stdout + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_file_operations(self, daytona_available): + """Test file operations with Daytona provider.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Upload a file + test_content = "Hello, Daytona file system!" 
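+            # Round-trip check: upload the file, read it back with cat, download it,
+            # and confirm it appears in the /tmp directory listing.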
+ await sandbox.upload_file("/tmp/test_file.txt", test_content) + + # Verify file exists + result = await sandbox.execute("cat /tmp/test_file.txt") + assert result.return_code == 0 + assert test_content in result.stdout + + # Download the file + downloaded = await sandbox.download_file("/tmp/test_file.txt") + assert downloaded.decode() == test_content + + # List files + files = await sandbox.list_files("/tmp") + file_names = [f.name for f in files] + assert "test_file.txt" in file_names + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_git_operations(self, daytona_available): + """Test Git operations in Daytona workspace.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=120) # Git operations might take longer + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Initialize a git repository + result = await sandbox.execute("git init /tmp/test_repo") + assert result.return_code == 0 + + # Configure git + await sandbox.execute( + "cd /tmp/test_repo && git config user.email 'test@example.com'" + ) + await sandbox.execute( + "cd /tmp/test_repo && git config user.name 'Test User'" + ) + + # Create and commit a file + await sandbox.upload_file("/tmp/test_repo/README.md", "# Test Repository") + result = await sandbox.execute("cd /tmp/test_repo && git add README.md") + assert result.return_code == 0 + + result = await sandbox.execute( + "cd /tmp/test_repo && git commit -m 'Initial commit'" + ) + assert result.return_code == 0 + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_package_installation(self, daytona_available): + """Test package installation on Daytona.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=120) # Longer timeout for package installation + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Install a package + result = await sandbox.execute("pip install requests") + assert result.return_code == 0 + + # Test the package + result = await sandbox.execute( + "python3 -c 'import requests; print(requests.__version__)'" + ) + assert result.return_code == 0 + assert len(result.stdout.strip()) > 0 + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_error_handling(self, daytona_available): + """Test error handling with Daytona provider.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=30) + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Test command that fails + result = await sandbox.execute("exit 1") + assert result.return_code == 1 + + # Test non-existent command + result = await sandbox.execute("nonexistent_command_12345") + assert result.return_code != 0 + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_timeout_handling(self, daytona_available): + """Test timeout 
handling with Daytona provider.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=5) # Short timeout + + async with Sandbox(provider="daytona", config=config) as sandbox: + # This should timeout + with pytest.raises(SandboxError): # Could be TimeoutError or SandboxError + await sandbox.execute("sleep 10") + + @pytest.mark.integration + @pytest.mark.daytona + async def test_daytona_invalid_api_key(self): + """Test Daytona provider with invalid API key.""" + # Create config with invalid API key + provider_config = ProviderConfig("daytona", {"api_key": "invalid_key"}) + + from grainchain.providers.daytona import DaytonaProvider + + provider = DaytonaProvider(provider_config) + config = SandboxConfig(timeout=30) + + with pytest.raises(ProviderError): + await provider.create_sandbox(config) + + @pytest.mark.integration + @pytest.mark.daytona + async def test_daytona_missing_api_key(self): + """Test Daytona provider without API key.""" + provider_config = ProviderConfig("daytona", {}) + + with pytest.raises(ConfigurationError, match="Daytona API key is required"): + from grainchain.providers.daytona import DaytonaProvider + + DaytonaProvider(provider_config) + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_workspace_template(self, daytona_available): + """Test Daytona provider with workspace template.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + # Test with a development template + provider_config = ProviderConfig( + "daytona", + { + "api_key": pytest.importorskip("os").getenv("DAYTONA_API_KEY"), + "workspace_template": "python-dev", + }, + ) + + from grainchain.providers.daytona import DaytonaProvider + + provider = DaytonaProvider(provider_config) + config = SandboxConfig(timeout=60) + + session = await provider.create_sandbox(config) + try: + # Test that Python development tools are available + result = await session.execute("python3 --version") + assert result.return_code == 0 + assert "Python" in result.stdout + + result = await session.execute("pip --version") + assert result.return_code == 0 + finally: + await session.close() + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.slow + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_multiple_workspaces(self, daytona_available): + """Test creating multiple Daytona workspaces.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=60) + + # Create multiple sandboxes concurrently + sandbox1 = Sandbox(provider="daytona", config=config) + sandbox2 = Sandbox(provider="daytona", config=config) + + async with sandbox1: + async with sandbox2: + # Both should work independently + result1 = await sandbox1.execute("echo 'Workspace 1'") + result2 = await sandbox2.execute("echo 'Workspace 2'") + + assert result1.return_code == 0 + assert result2.return_code == 0 + assert "Workspace 1" in result1.stdout + assert "Workspace 2" in result2.stdout + + # They should have different IDs + assert sandbox1.sandbox_id != sandbox2.sandbox_id + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK 
not available", + ) + async def test_daytona_environment_variables(self, daytona_available): + """Test environment variables with Daytona provider.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig( + timeout=60, + environment_vars={"TEST_VAR": "test_value", "ANOTHER_VAR": "another_value"}, + ) + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Test environment variables + result = await sandbox.execute("echo $TEST_VAR") + assert result.return_code == 0 + assert "test_value" in result.stdout + + result = await sandbox.execute("echo $ANOTHER_VAR") + assert result.return_code == 0 + assert "another_value" in result.stdout + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_working_directory(self, daytona_available): + """Test working directory with Daytona provider.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=60, working_directory="/workspace") + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Test working directory + result = await sandbox.execute("pwd") + assert result.return_code == 0 + assert "/workspace" in result.stdout.strip() + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_development_workflow(self, daytona_available): + """Test a complete development workflow on Daytona.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=120) + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Create a simple Python project + project_structure = { + "main.py": """ +def greet(name): + return f"Hello, {name}!" + +if __name__ == "__main__": + print(greet("Daytona")) +""", + "requirements.txt": "requests==2.31.0", + "README.md": "# Test Project\n\nA simple test project for Daytona.", + } + + # Upload project files + for filename, content in project_structure.items(): + await sandbox.upload_file(f"/workspace/{filename}", content) + + # Install dependencies + result = await sandbox.execute( + "cd /workspace && pip install -r requirements.txt" + ) + assert result.return_code == 0 + + # Run the main script + result = await sandbox.execute("cd /workspace && python main.py") + assert result.return_code == 0 + assert "Hello, Daytona!" 
in result.stdout + + # Test that requirements were installed + result = await sandbox.execute( + "python -c 'import requests; print(\"Requests available\")'" + ) + assert result.return_code == 0 + assert "Requests available" in result.stdout + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_large_output(self, daytona_available): + """Test handling of large command output.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Generate large output + result = await sandbox.execute("python3 -c 'print(\"x\" * 10000)'") + assert result.return_code == 0 + assert len(result.stdout) >= 10000 + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_binary_file_operations(self, daytona_available): + """Test binary file operations with Daytona provider.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Create binary content + binary_content = bytes(range(256)) + + # Upload binary file + await sandbox.upload_file( + "/workspace/binary_test.bin", binary_content, mode="wb" + ) + + # Download and verify + downloaded = await sandbox.download_file("/workspace/binary_test.bin") + assert downloaded == binary_content + + @pytest.mark.integration + @pytest.mark.daytona + @pytest.mark.skipif( + not pytest.importorskip("daytona_sdk", reason="Daytona SDK not installed"), + reason="Daytona SDK not available", + ) + async def test_daytona_persistent_workspace(self, daytona_available): + """Test workspace persistence features.""" + if not daytona_available: + pytest.skip("Daytona API key not available") + + config = SandboxConfig(timeout=60, keep_alive=True) + + async with Sandbox(provider="daytona", config=config) as sandbox: + # Create a file that should persist + await sandbox.upload_file( + "/workspace/persistent_file.txt", "This should persist" + ) + + # Verify file exists + result = await sandbox.execute("cat /workspace/persistent_file.txt") + assert result.return_code == 0 + assert "This should persist" in result.stdout + + # Note: In a real scenario, we would test that the workspace + # persists across sessions, but that's complex for automated tests diff --git a/tests/integration/test_e2b_provider.py b/tests/integration/test_e2b_provider.py new file mode 100644 index 0000000..c79ff9d --- /dev/null +++ b/tests/integration/test_e2b_provider.py @@ -0,0 +1,331 @@ +"""Integration tests for E2B provider.""" + +import pytest + +from grainchain import Sandbox, SandboxConfig +from grainchain.core.config import ProviderConfig +from grainchain.core.exceptions import ConfigurationError, ProviderError, SandboxError + + +class TestE2BProviderIntegration: + """Integration tests for E2B provider.""" + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_provider_real_connection(self, e2b_available): + """Test real E2B provider connection (requires API key).""" + if not e2b_available: + 
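+            # e2b_available is a conftest.py fixture; it is True only when the e2b
+            # package imports successfully and E2B_API_KEY is set in the environment.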
pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="e2b", config=config) as sandbox: + assert sandbox.provider_name == "e2b" + assert sandbox.sandbox_id is not None + + # Test basic command execution + result = await sandbox.execute("echo 'Hello from E2B!'") + assert result.return_code == 0 + assert "Hello from E2B!" in result.stdout + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_python_execution(self, e2b_available): + """Test Python code execution on E2B.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="e2b", config=config) as sandbox: + # Test Python execution + python_code = "import sys; print(f'Python {sys.version_info.major}.{sys.version_info.minor}')" + result = await sandbox.execute(f'python3 -c "{python_code}"') + + assert result.return_code == 0 + assert "Python" in result.stdout + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_file_operations(self, e2b_available): + """Test file operations with E2B provider.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="e2b", config=config) as sandbox: + # Upload a file + test_content = "Hello, E2B file system!" + await sandbox.upload_file("/tmp/test_file.txt", test_content) + + # Verify file exists + result = await sandbox.execute("cat /tmp/test_file.txt") + assert result.return_code == 0 + assert test_content in result.stdout + + # Download the file + downloaded = await sandbox.download_file("/tmp/test_file.txt") + assert downloaded.decode() == test_content + + # List files + files = await sandbox.list_files("/tmp") + file_names = [f.name for f in files] + assert "test_file.txt" in file_names + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_package_installation(self, e2b_available): + """Test package installation on E2B.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=120) # Longer timeout for package installation + + async with Sandbox(provider="e2b", config=config) as sandbox: + # Install a package + result = await sandbox.execute("pip install requests") + assert result.return_code == 0 + + # Test the package + result = await sandbox.execute( + "python3 -c 'import requests; print(requests.__version__)'" + ) + assert result.return_code == 0 + assert len(result.stdout.strip()) > 0 + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_error_handling(self, e2b_available): + """Test error handling with E2B provider.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=30) + + async with Sandbox(provider="e2b", config=config) as sandbox: + # Test command that fails + result = await sandbox.execute("exit 1") + assert result.return_code == 1 + + # Test non-existent command + result = await 
sandbox.execute("nonexistent_command_12345") + assert result.return_code != 0 + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_timeout_handling(self, e2b_available): + """Test timeout handling with E2B provider.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=5) # Short timeout + + async with Sandbox(provider="e2b", config=config) as sandbox: + # This should timeout + with pytest.raises(SandboxError): # Could be TimeoutError or SandboxError + await sandbox.execute("sleep 10") + + @pytest.mark.integration + @pytest.mark.e2b + async def test_e2b_invalid_api_key(self): + """Test E2B provider with invalid API key.""" + # Create config with invalid API key + provider_config = ProviderConfig("e2b", {"api_key": "invalid_key"}) + + from grainchain.providers.e2b import E2BProvider + + provider = E2BProvider(provider_config) + config = SandboxConfig(timeout=30) + + with pytest.raises(ProviderError): + await provider.create_sandbox(config) + + @pytest.mark.integration + @pytest.mark.e2b + async def test_e2b_missing_api_key(self): + """Test E2B provider without API key.""" + provider_config = ProviderConfig("e2b", {}) + + with pytest.raises(ConfigurationError, match="E2B API key is required"): + from grainchain.providers.e2b import E2BProvider + + E2BProvider(provider_config) + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_custom_template(self, e2b_available): + """Test E2B provider with custom template.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + # Test with Python template (should be available) + provider_config = ProviderConfig( + "e2b", + { + "api_key": pytest.importorskip("os").getenv("E2B_API_KEY"), + "template": "python3", + }, + ) + + from grainchain.providers.e2b import E2BProvider + + provider = E2BProvider(provider_config) + config = SandboxConfig(timeout=60) + + session = await provider.create_sandbox(config) + try: + # Test that Python is available + result = await session.execute("python3 --version") + assert result.return_code == 0 + assert "Python" in result.stdout + finally: + await session.close() + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.slow + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_multiple_sessions(self, e2b_available): + """Test creating multiple E2B sessions.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=60) + + # Create multiple sandboxes concurrently + sandbox1 = Sandbox(provider="e2b", config=config) + sandbox2 = Sandbox(provider="e2b", config=config) + + async with sandbox1: + async with sandbox2: + # Both should work independently + result1 = await sandbox1.execute("echo 'Sandbox 1'") + result2 = await sandbox2.execute("echo 'Sandbox 2'") + + assert result1.return_code == 0 + assert result2.return_code == 0 + assert "Sandbox 1" in result1.stdout + assert "Sandbox 2" in result2.stdout + + # They should have different IDs + assert sandbox1.sandbox_id != sandbox2.sandbox_id + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B 
package not available", + ) + async def test_e2b_environment_variables(self, e2b_available): + """Test environment variables with E2B provider.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig( + timeout=60, + environment_vars={"TEST_VAR": "test_value", "ANOTHER_VAR": "another_value"}, + ) + + async with Sandbox(provider="e2b", config=config) as sandbox: + # Test environment variables + result = await sandbox.execute("echo $TEST_VAR") + assert result.return_code == 0 + assert "test_value" in result.stdout + + result = await sandbox.execute("echo $ANOTHER_VAR") + assert result.return_code == 0 + assert "another_value" in result.stdout + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_working_directory(self, e2b_available): + """Test working directory with E2B provider.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=60, working_directory="/tmp") + + async with Sandbox(provider="e2b", config=config) as sandbox: + # Test working directory + result = await sandbox.execute("pwd") + assert result.return_code == 0 + assert "/tmp" in result.stdout.strip() + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_large_output(self, e2b_available): + """Test handling of large command output.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="e2b", config=config) as sandbox: + # Generate large output + result = await sandbox.execute("python3 -c 'print(\"x\" * 10000)'") + assert result.return_code == 0 + assert len(result.stdout) >= 10000 + + @pytest.mark.integration + @pytest.mark.e2b + @pytest.mark.skipif( + not pytest.importorskip("e2b", reason="E2B not installed"), + reason="E2B package not available", + ) + async def test_e2b_binary_file_operations(self, e2b_available): + """Test binary file operations with E2B provider.""" + if not e2b_available: + pytest.skip("E2B API key not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="e2b", config=config) as sandbox: + # Create binary content + binary_content = bytes(range(256)) + + # Upload binary file + await sandbox.upload_file("/tmp/binary_test.bin", binary_content, mode="wb") + + # Download and verify + downloaded = await sandbox.download_file("/tmp/binary_test.bin") + assert downloaded == binary_content diff --git a/tests/integration/test_local_provider.py b/tests/integration/test_local_provider.py new file mode 100644 index 0000000..73e83fb --- /dev/null +++ b/tests/integration/test_local_provider.py @@ -0,0 +1,449 @@ +"""Integration tests for Local provider.""" + +import asyncio + +import pytest + +from grainchain import Sandbox, SandboxConfig +from grainchain.core.exceptions import SandboxError + + +class TestLocalProviderIntegration: + """Integration tests for Local provider.""" + + @pytest.mark.integration + @pytest.mark.local + async def test_local_provider_basic_functionality(self, temp_dir): + """Test basic Local provider functionality.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + assert sandbox.provider_name == "local" + assert 
sandbox.sandbox_id is not None + + # Test basic command execution + result = await sandbox.execute("echo 'Hello from Local!'") + assert result.return_code == 0 + assert "Hello from Local!" in result.stdout + + @pytest.mark.integration + @pytest.mark.local + async def test_local_python_execution(self, temp_dir): + """Test Python code execution with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Test Python execution + python_code = "import sys; print(f'Python {sys.version_info.major}.{sys.version_info.minor}')" + result = await sandbox.execute(f'python3 -c "{python_code}"') + + assert result.return_code == 0 + assert "Python" in result.stdout + + @pytest.mark.integration + @pytest.mark.local + async def test_local_file_operations(self, temp_dir): + """Test file operations with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Upload a file + test_content = "Hello, Local file system!" + await sandbox.upload_file("test_file.txt", test_content) + + # Verify file exists on disk + file_path = temp_dir / "test_file.txt" + assert file_path.exists() + assert file_path.read_text() == test_content + + # Download the file + downloaded = await sandbox.download_file("test_file.txt") + assert downloaded.decode() == test_content + + # List files + files = await sandbox.list_files(".") + file_names = [f.name for f in files] + assert "test_file.txt" in file_names + + @pytest.mark.integration + @pytest.mark.local + async def test_local_binary_file_operations(self, temp_dir): + """Test binary file operations with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Create binary content + binary_content = bytes(range(256)) + + # Upload binary file + await sandbox.upload_file("binary_test.bin", binary_content, mode="wb") + + # Verify file exists on disk + file_path = temp_dir / "binary_test.bin" + assert file_path.exists() + assert file_path.read_bytes() == binary_content + + # Download and verify + downloaded = await sandbox.download_file("binary_test.bin") + assert downloaded == binary_content + + @pytest.mark.integration + @pytest.mark.local + async def test_local_directory_operations(self, temp_dir): + """Test directory operations with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Create directory structure + await sandbox.execute("mkdir -p subdir/nested") + + # Upload files to different directories + await sandbox.upload_file("subdir/file1.txt", "File 1 content") + await sandbox.upload_file("subdir/nested/file2.txt", "File 2 content") + + # List files in subdirectory + files = await sandbox.list_files("subdir") + file_names = [f.name for f in files] + assert "file1.txt" in file_names + assert "nested" in file_names + + # List files in nested directory + nested_files = await sandbox.list_files("subdir/nested") + nested_file_names = [f.name for f in nested_files] + assert "file2.txt" in nested_file_names + + @pytest.mark.integration + @pytest.mark.local + async def test_local_command_with_pipes(self, temp_dir): + """Test command execution with pipes and redirects.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with 
Sandbox(provider="local", config=config) as sandbox: + # Test pipe operations + result = await sandbox.execute("echo 'hello world' | grep 'world'") + assert result.return_code == 0 + assert "world" in result.stdout + + # Test output redirection + result = await sandbox.execute("echo 'test content' > output.txt") + assert result.return_code == 0 + + # Verify file was created + result = await sandbox.execute("cat output.txt") + assert result.return_code == 0 + assert "test content" in result.stdout + + @pytest.mark.integration + @pytest.mark.local + async def test_local_environment_variables(self, temp_dir): + """Test environment variables with Local provider.""" + config = SandboxConfig( + timeout=30, + working_directory=str(temp_dir), + environment_vars={"TEST_VAR": "test_value", "ANOTHER_VAR": "another_value"}, + ) + + async with Sandbox(provider="local", config=config) as sandbox: + # Test environment variables + result = await sandbox.execute("echo $TEST_VAR") + assert result.return_code == 0 + assert "test_value" in result.stdout + + result = await sandbox.execute("echo $ANOTHER_VAR") + assert result.return_code == 0 + assert "another_value" in result.stdout + + @pytest.mark.integration + @pytest.mark.local + async def test_local_working_directory(self, temp_dir): + """Test working directory with Local provider.""" + subdir = temp_dir / "workdir" + subdir.mkdir() + + config = SandboxConfig(timeout=30, working_directory=str(subdir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Test working directory + result = await sandbox.execute("pwd") + assert result.return_code == 0 + assert str(subdir) in result.stdout.strip() + + @pytest.mark.integration + @pytest.mark.local + async def test_local_error_handling(self, temp_dir): + """Test error handling with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Test command that fails + result = await sandbox.execute("exit 1") + assert result.return_code == 1 + + # Test non-existent command + result = await sandbox.execute("nonexistent_command_12345") + assert result.return_code != 0 + + @pytest.mark.integration + @pytest.mark.local + async def test_local_timeout_handling(self, temp_dir): + """Test timeout handling with Local provider.""" + config = SandboxConfig( + timeout=2, # Very short timeout + working_directory=str(temp_dir), + ) + + async with Sandbox(provider="local", config=config) as sandbox: + # This should timeout + with pytest.raises(SandboxError, match="Command execution failed"): + await sandbox.execute("sleep 5") + + @pytest.mark.integration + @pytest.mark.local + @pytest.mark.snapshot + async def test_local_snapshot_functionality(self, temp_dir): + """Test snapshot functionality with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Create initial state + await sandbox.upload_file("initial.txt", "Initial content") + await sandbox.execute("mkdir initial_dir") + + # Create snapshot + snapshot_id = await sandbox.create_snapshot() + assert isinstance(snapshot_id, str) + assert len(snapshot_id) > 0 + + # Make changes after snapshot + await sandbox.upload_file("after_snapshot.txt", "After snapshot content") + await sandbox.execute("mkdir after_dir") + + # Verify changes exist + files = await sandbox.list_files(".") + file_names = [f.name for f in files] + assert "after_snapshot.txt" in 
file_names + assert "after_dir" in file_names + + # Restore snapshot + await sandbox.restore_snapshot(snapshot_id) + + # Verify restoration + files = await sandbox.list_files(".") + file_names = [f.name for f in files] + assert "initial.txt" in file_names + assert "initial_dir" in file_names + assert "after_snapshot.txt" not in file_names + assert "after_dir" not in file_names + + @pytest.mark.integration + @pytest.mark.local + @pytest.mark.snapshot + async def test_local_multiple_snapshots(self, temp_dir): + """Test multiple snapshots with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # State 1 + await sandbox.upload_file("state1.txt", "State 1") + snapshot1 = await sandbox.create_snapshot() + + # State 2 + await sandbox.upload_file("state2.txt", "State 2") + snapshot2 = await sandbox.create_snapshot() + + # State 3 + await sandbox.upload_file("state3.txt", "State 3") + + # Restore to state 1 + await sandbox.restore_snapshot(snapshot1) + files = await sandbox.list_files(".") + file_names = [f.name for f in files] + assert "state1.txt" in file_names + assert "state2.txt" not in file_names + assert "state3.txt" not in file_names + + # Restore to state 2 + await sandbox.restore_snapshot(snapshot2) + files = await sandbox.list_files(".") + file_names = [f.name for f in files] + assert "state1.txt" in file_names + assert "state2.txt" in file_names + assert "state3.txt" not in file_names + + @pytest.mark.integration + @pytest.mark.local + async def test_local_concurrent_operations(self, temp_dir): + """Test concurrent operations within a Local sandbox.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Upload multiple files concurrently + upload_tasks = [] + for i in range(10): + content = f"File {i} content" + task = sandbox.upload_file(f"file_{i}.txt", content) + upload_tasks.append(task) + + await asyncio.gather(*upload_tasks) + + # Verify all files were uploaded + files = await sandbox.list_files(".") + file_names = [f.name for f in files] + + for i in range(10): + assert f"file_{i}.txt" in file_names + + @pytest.mark.integration + @pytest.mark.local + async def test_local_large_file_operations(self, temp_dir): + """Test operations with large files.""" + config = SandboxConfig( + timeout=60, # Longer timeout for large files + working_directory=str(temp_dir), + ) + + async with Sandbox(provider="local", config=config) as sandbox: + # Create large content (1MB) + large_content = "x" * (1024 * 1024) + + # Upload large file + await sandbox.upload_file("large_file.txt", large_content) + + # Verify file size + result = await sandbox.execute("wc -c large_file.txt") + assert result.return_code == 0 + assert "1048576" in result.stdout # 1MB in bytes + + # Download and verify + downloaded = await sandbox.download_file("large_file.txt") + assert len(downloaded) == 1024 * 1024 + assert downloaded.decode() == large_content + + @pytest.mark.integration + @pytest.mark.local + async def test_local_package_installation(self, temp_dir): + """Test package installation with Local provider.""" + config = SandboxConfig( + timeout=120, # Longer timeout for package installation + working_directory=str(temp_dir), + ) + + async with Sandbox(provider="local", config=config) as sandbox: + # Install a package (if pip is available) + result = await sandbox.execute("python3 -m pip --version") + if result.return_code 
== 0: + # pip is available, test installation + result = await sandbox.execute("python3 -m pip install --user requests") + assert result.return_code == 0 + + # Test the package + result = await sandbox.execute( + "python3 -c 'import requests; print(requests.__version__)'" + ) + assert result.return_code == 0 + assert len(result.stdout.strip()) > 0 + + @pytest.mark.integration + @pytest.mark.local + async def test_local_script_execution(self, temp_dir): + """Test script execution with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Create a Python script + script_content = """#!/usr/bin/env python3 +import sys +import json + +def main(): + data = { + "message": "Hello from script!", + "args": sys.argv[1:], + "python_version": f"{sys.version_info.major}.{sys.version_info.minor}" + } + print(json.dumps(data, indent=2)) + +if __name__ == "__main__": + main() +""" + + await sandbox.upload_file("test_script.py", script_content) + + # Make script executable + result = await sandbox.execute("chmod +x test_script.py") + assert result.return_code == 0 + + # Run script + result = await sandbox.execute("python3 test_script.py arg1 arg2") + assert result.return_code == 0 + assert "Hello from script!" in result.stdout + assert "arg1" in result.stdout + assert "arg2" in result.stdout + + @pytest.mark.integration + @pytest.mark.local + async def test_local_multiple_sessions(self, temp_dir): + """Test creating multiple Local sessions.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + # Create multiple sandboxes + sandbox1 = Sandbox(provider="local", config=config) + sandbox2 = Sandbox(provider="local", config=config) + + async with sandbox1: + async with sandbox2: + # Both should work independently + result1 = await sandbox1.execute("echo 'Sandbox 1'") + result2 = await sandbox2.execute("echo 'Sandbox 2'") + + assert result1.return_code == 0 + assert result2.return_code == 0 + assert "Sandbox 1" in result1.stdout + assert "Sandbox 2" in result2.stdout + + # They should have different IDs + assert sandbox1.sandbox_id != sandbox2.sandbox_id + + @pytest.mark.integration + @pytest.mark.local + async def test_local_file_permissions(self, temp_dir): + """Test file permissions with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Create a file and set permissions + await sandbox.upload_file("test_permissions.txt", "Test content") + + # Make file executable + result = await sandbox.execute("chmod +x test_permissions.txt") + assert result.return_code == 0 + + # Check permissions + result = await sandbox.execute("ls -l test_permissions.txt") + assert result.return_code == 0 + assert "x" in result.stdout # Should have execute permission + + @pytest.mark.integration + @pytest.mark.local + async def test_local_symlink_operations(self, temp_dir): + """Test symbolic link operations with Local provider.""" + config = SandboxConfig(timeout=30, working_directory=str(temp_dir)) + + async with Sandbox(provider="local", config=config) as sandbox: + # Create a file + await sandbox.upload_file("original.txt", "Original content") + + # Create symbolic link + result = await sandbox.execute("ln -s original.txt link.txt") + assert result.return_code == 0 + + # Read through symlink + result = await sandbox.execute("cat link.txt") + assert result.return_code == 0 + assert "Original 
content" in result.stdout + + # List files to see symlink + files = await sandbox.list_files(".") + file_names = [f.name for f in files] + assert "original.txt" in file_names + assert "link.txt" in file_names diff --git a/tests/integration/test_modal_provider.py b/tests/integration/test_modal_provider.py new file mode 100644 index 0000000..e826419 --- /dev/null +++ b/tests/integration/test_modal_provider.py @@ -0,0 +1,412 @@ +"""Integration tests for Modal provider.""" + +import pytest + +from grainchain import Sandbox, SandboxConfig +from grainchain.core.config import ProviderConfig +from grainchain.core.exceptions import ConfigurationError, ProviderError, SandboxError + + +class TestModalProviderIntegration: + """Integration tests for Modal provider.""" + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_provider_real_connection(self, modal_available): + """Test real Modal provider connection (requires credentials).""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="modal", config=config) as sandbox: + assert sandbox.provider_name == "modal" + assert sandbox.sandbox_id is not None + + # Test basic command execution + result = await sandbox.execute("echo 'Hello from Modal!'") + assert result.return_code == 0 + assert "Hello from Modal!" in result.stdout + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_python_execution(self, modal_available): + """Test Python code execution on Modal.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Test Python execution + python_code = "import sys; print(f'Python {sys.version_info.major}.{sys.version_info.minor}')" + result = await sandbox.execute(f'python -c "{python_code}"') + + assert result.return_code == 0 + assert "Python" in result.stdout + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_file_operations(self, modal_available): + """Test file operations with Modal provider.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Upload a file + test_content = "Hello, Modal file system!" 
+ await sandbox.upload_file("/tmp/test_file.txt", test_content) + + # Verify file exists + result = await sandbox.execute("cat /tmp/test_file.txt") + assert result.return_code == 0 + assert test_content in result.stdout + + # Download the file + downloaded = await sandbox.download_file("/tmp/test_file.txt") + assert downloaded.decode() == test_content + + # List files + files = await sandbox.list_files("/tmp") + file_names = [f.name for f in files] + assert "test_file.txt" in file_names + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_package_installation(self, modal_available): + """Test package installation on Modal.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=120) # Longer timeout for package installation + + async with Sandbox(provider="modal", config=config) as sandbox: + # Install a package + result = await sandbox.execute("pip install requests") + assert result.return_code == 0 + + # Test the package + result = await sandbox.execute( + "python -c 'import requests; print(requests.__version__)'" + ) + assert result.return_code == 0 + assert len(result.stdout.strip()) > 0 + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_error_handling(self, modal_available): + """Test error handling with Modal provider.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=30) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Test command that fails + result = await sandbox.execute("exit 1") + assert result.return_code == 1 + + # Test non-existent command + result = await sandbox.execute("nonexistent_command_12345") + assert result.return_code != 0 + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_timeout_handling(self, modal_available): + """Test timeout handling with Modal provider.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=5) # Short timeout + + async with Sandbox(provider="modal", config=config) as sandbox: + # This should timeout + with pytest.raises(SandboxError): # Could be TimeoutError or SandboxError + await sandbox.execute("sleep 10") + + @pytest.mark.integration + @pytest.mark.modal + async def test_modal_invalid_credentials(self): + """Test Modal provider with invalid credentials.""" + # Create config with invalid credentials + provider_config = ProviderConfig( + "modal", {"token_id": "invalid_id", "token_secret": "invalid_secret"} + ) + + from grainchain.providers.modal import ModalProvider + + provider = ModalProvider(provider_config) + config = SandboxConfig(timeout=30) + + with pytest.raises(ProviderError): + await provider.create_sandbox(config) + + @pytest.mark.integration + @pytest.mark.modal + async def test_modal_missing_credentials(self): + """Test Modal provider without credentials.""" + provider_config = ProviderConfig("modal", {}) + + with pytest.raises(ConfigurationError, match="Modal credentials are required"): + from grainchain.providers.modal import ModalProvider + + 
ModalProvider(provider_config) + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_custom_image(self, modal_available): + """Test Modal provider with custom image.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig( + timeout=60, + image="python:3.11", # Custom Python image + ) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Test that Python is available + result = await sandbox.execute("python --version") + assert result.return_code == 0 + assert "Python" in result.stdout + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.slow + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_multiple_sessions(self, modal_available): + """Test creating multiple Modal sessions.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=60) + + # Create multiple sandboxes concurrently + sandbox1 = Sandbox(provider="modal", config=config) + sandbox2 = Sandbox(provider="modal", config=config) + + async with sandbox1: + async with sandbox2: + # Both should work independently + result1 = await sandbox1.execute("echo 'Sandbox 1'") + result2 = await sandbox2.execute("echo 'Sandbox 2'") + + assert result1.return_code == 0 + assert result2.return_code == 0 + assert "Sandbox 1" in result1.stdout + assert "Sandbox 2" in result2.stdout + + # They should have different IDs + assert sandbox1.sandbox_id != sandbox2.sandbox_id + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_environment_variables(self, modal_available): + """Test environment variables with Modal provider.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig( + timeout=60, + environment_vars={"TEST_VAR": "test_value", "ANOTHER_VAR": "another_value"}, + ) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Test environment variables + result = await sandbox.execute("echo $TEST_VAR") + assert result.return_code == 0 + assert "test_value" in result.stdout + + result = await sandbox.execute("echo $ANOTHER_VAR") + assert result.return_code == 0 + assert "another_value" in result.stdout + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_working_directory(self, modal_available): + """Test working directory with Modal provider.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=60, working_directory="/tmp") + + async with Sandbox(provider="modal", config=config) as sandbox: + # Test working directory + result = await sandbox.execute("pwd") + assert result.return_code == 0 + assert "/tmp" in result.stdout.strip() + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_resource_limits(self, modal_available): + """Test resource limits with Modal provider.""" + if not modal_available: 
+ pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=60, memory_limit="1GB", cpu_limit=1.0) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Test that sandbox respects resource limits + result = await sandbox.execute("echo 'Resource limits test'") + assert result.return_code == 0 + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_large_output(self, modal_available): + """Test handling of large command output.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Generate large output + result = await sandbox.execute("python -c 'print(\"x\" * 10000)'") + assert result.return_code == 0 + assert len(result.stdout) >= 10000 + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_binary_file_operations(self, modal_available): + """Test binary file operations with Modal provider.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Create binary content + binary_content = bytes(range(256)) + + # Upload binary file + await sandbox.upload_file("/tmp/binary_test.bin", binary_content, mode="wb") + + # Download and verify + downloaded = await sandbox.download_file("/tmp/binary_test.bin") + assert downloaded == binary_content + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_concurrent_operations(self, modal_available): + """Test concurrent operations within a Modal sandbox.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=60) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Upload multiple files concurrently + import asyncio + + upload_tasks = [] + for i in range(5): + content = f"File {i} content" + task = sandbox.upload_file(f"/tmp/file_{i}.txt", content) + upload_tasks.append(task) + + await asyncio.gather(*upload_tasks) + + # Verify all files were uploaded + files = await sandbox.list_files("/tmp") + file_names = [f.name for f in files] + + for i in range(5): + assert f"file_{i}.txt" in file_names + + @pytest.mark.integration + @pytest.mark.modal + @pytest.mark.skipif( + not pytest.importorskip("modal", reason="Modal not installed"), + reason="Modal package not available", + ) + async def test_modal_data_science_workflow(self, modal_available): + """Test a typical data science workflow on Modal.""" + if not modal_available: + pytest.skip("Modal credentials not available") + + config = SandboxConfig(timeout=120) + + async with Sandbox(provider="modal", config=config) as sandbox: + # Install data science packages + result = await sandbox.execute("pip install numpy pandas") + assert result.return_code == 0 + + # Create a simple data analysis script + script_content = """ +import numpy as np +import pandas as pd + +# Create sample data +data = np.random.randn(100, 3) +df = pd.DataFrame(data, columns=['A', 'B', 'C']) + +# Basic analysis 
+print(f"Shape: {df.shape}") +print(f"Mean: {df.mean().to_dict()}") +print("Analysis complete!") +""" + + await sandbox.upload_file("/tmp/analysis.py", script_content) + + # Run the analysis + result = await sandbox.execute("python /tmp/analysis.py") + assert result.return_code == 0 + assert "Shape: (100, 3)" in result.stdout + assert "Analysis complete!" in result.stdout diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..f4ccbcd --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,487 @@ +"""Unit tests for configuration management.""" + +import os +import tempfile +from unittest.mock import patch + +import pytest +import yaml + +from grainchain.core.config import ( + ConfigManager, + ProviderConfig, + SandboxConfig, + get_config_manager, + set_config_manager, +) +from grainchain.core.exceptions import ConfigurationError + + +@pytest.fixture +def clean_env(): + """Fixture to temporarily unset GRAINCHAIN_DEFAULT_PROVIDER for testing.""" + original_value = os.environ.get("GRAINCHAIN_DEFAULT_PROVIDER") + if "GRAINCHAIN_DEFAULT_PROVIDER" in os.environ: + del os.environ["GRAINCHAIN_DEFAULT_PROVIDER"] + + yield + + # Restore original value + if original_value is not None: + os.environ["GRAINCHAIN_DEFAULT_PROVIDER"] = original_value + + +class TestProviderConfig: + """Test cases for ProviderConfig.""" + + @pytest.mark.unit + def test_provider_config_init(self): + """Test ProviderConfig initialization.""" + config = ProviderConfig("test_provider") + + assert config.name == "test_provider" + assert config.config == {} + + @pytest.mark.unit + def test_provider_config_with_data(self): + """Test ProviderConfig with initial data.""" + data = {"api_key": "test_key", "timeout": 30} + config = ProviderConfig("test_provider", data) + + assert config.name == "test_provider" + assert config.config == data + + @pytest.mark.unit + def test_provider_config_get(self): + """Test getting configuration values.""" + config = ProviderConfig("test", {"key1": "value1", "key2": "value2"}) + + assert config.get("key1") == "value1" + assert config.get("key2") == "value2" + assert config.get("missing") is None + assert config.get("missing", "default") == "default" + + @pytest.mark.unit + def test_provider_config_set(self): + """Test setting configuration values.""" + config = ProviderConfig("test") + + config.set("new_key", "new_value") + assert config.get("new_key") == "new_value" + + config.set("new_key", "updated_value") + assert config.get("new_key") == "updated_value" + + +class TestSandboxConfig: + """Test cases for SandboxConfig.""" + + @pytest.mark.unit + def test_sandbox_config_defaults(self): + """Test SandboxConfig default values.""" + config = SandboxConfig() + + assert config.timeout == 300 + assert config.memory_limit is None + assert config.cpu_limit is None + assert config.image is None + assert config.working_directory == "~" + assert config.environment_vars == {} + assert config.auto_cleanup is True + assert config.keep_alive is False + assert config.provider_config == {} + + @pytest.mark.unit + def test_sandbox_config_custom_values(self): + """Test SandboxConfig with custom values.""" + env_vars = {"TEST_VAR": "test_value"} + provider_config = {"custom_setting": "custom_value"} + + config = SandboxConfig( + timeout=600, + memory_limit="2GB", + cpu_limit=2.0, + image="custom:latest", + working_directory="/workspace", + environment_vars=env_vars, + auto_cleanup=False, + keep_alive=True, + provider_config=provider_config, + ) + + assert config.timeout 
== 600 + assert config.memory_limit == "2GB" + assert config.cpu_limit == 2.0 + assert config.image == "custom:latest" + assert config.working_directory == "/workspace" + assert config.environment_vars == env_vars + assert config.auto_cleanup is False + assert config.keep_alive is True + assert config.provider_config == provider_config + + @pytest.mark.unit + def test_sandbox_config_mutable_defaults(self): + """Test that mutable defaults are properly handled.""" + config1 = SandboxConfig() + config2 = SandboxConfig() + + # Modify one config's environment vars + config1.environment_vars["TEST"] = "value" + + # Other config should not be affected + assert "TEST" not in config2.environment_vars + + +class TestConfigManager: + """Test cases for ConfigManager.""" + + @pytest.mark.unit + def test_config_manager_init_empty(self): + """Test ConfigManager initialization without config file.""" + with patch.object(ConfigManager, "_load_config_file"): + with patch.object(ConfigManager, "_load_env_config"): + manager = ConfigManager() + assert manager._config == {} + + @pytest.mark.unit + def test_config_manager_default_provider(self): + """Test default provider configuration.""" + # Temporarily unset environment variable to test default + import os + + original_value = os.environ.get("GRAINCHAIN_DEFAULT_PROVIDER") + if "GRAINCHAIN_DEFAULT_PROVIDER" in os.environ: + del os.environ["GRAINCHAIN_DEFAULT_PROVIDER"] + + try: + manager = ConfigManager() + assert manager.default_provider == "e2b" # Default fallback + finally: + # Restore original value + if original_value is not None: + os.environ["GRAINCHAIN_DEFAULT_PROVIDER"] = original_value + + @pytest.mark.unit + def test_config_manager_custom_default_provider(self): + """Test custom default provider.""" + manager = ConfigManager() + manager._config["default_provider"] = "custom" + assert manager.default_provider == "custom" + + @pytest.mark.unit + def test_get_provider_config_existing(self): + """Test getting existing provider configuration.""" + manager = ConfigManager() + manager._providers["test"] = ProviderConfig("test", {"key": "value"}) + + config = manager.get_provider_config("test") + assert config.name == "test" + assert config.get("key") == "value" + + @pytest.mark.unit + def test_get_provider_config_new(self): + """Test getting configuration for new provider.""" + manager = ConfigManager() + + config = manager.get_provider_config("new_provider") + assert config.name == "new_provider" + assert config.config == {} + + @pytest.mark.unit + def test_get_sandbox_defaults(self): + """Test getting sandbox defaults.""" + manager = ConfigManager() + manager._config["sandbox_defaults"] = { + "timeout": 600, + "working_directory": "/custom", + } + + config = manager.get_sandbox_defaults() + assert config.timeout == 600 + assert config.working_directory == "/custom" + + @pytest.mark.unit + def test_get_sandbox_defaults_empty(self): + """Test getting sandbox defaults when none configured.""" + manager = ConfigManager() + + config = manager.get_sandbox_defaults() + assert isinstance(config, SandboxConfig) + assert config.timeout == 300 # Default value + + @pytest.mark.unit + def test_config_get_set(self): + """Test getting and setting configuration values.""" + manager = ConfigManager() + + assert manager.get("missing") is None + assert manager.get("missing", "default") == "default" + + manager.set("test_key", "test_value") + assert manager.get("test_key") == "test_value" + + @pytest.mark.unit + def test_load_config_file_yaml(self, clean_env): + """Test 
loading configuration from YAML file.""" + config_data = { + "default_provider": "test", + "providers": { + "test": {"api_key": "test_key"}, + }, + "sandbox_defaults": {"timeout": 600}, + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config_data, f) + config_path = f.name + + try: + manager = ConfigManager(config_path) + assert manager.default_provider == "test" + assert manager.get_provider_config("test").get("api_key") == "test_key" + assert manager.get_sandbox_defaults().timeout == 600 + finally: + os.unlink(config_path) + + @pytest.mark.unit + def test_load_config_file_invalid(self): + """Test loading invalid configuration file.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("invalid: yaml: content: [") + config_path = f.name + + try: + with pytest.raises(ConfigurationError, match="Failed to load config file"): + ConfigManager(config_path) + finally: + os.unlink(config_path) + + @pytest.mark.unit + def test_load_config_file_not_found(self): + """Test loading non-existent configuration file.""" + # Should raise ConfigurationError for explicit non-existent file + with pytest.raises(ConfigurationError, match="Failed to load config file"): + ConfigManager("/nonexistent/config.yaml") + + @pytest.mark.unit + def test_load_env_config(self, env_vars): + """Test loading configuration from environment variables.""" + manager = ConfigManager() + + assert manager.default_provider == "test" + assert manager.get_provider_config("e2b").get("api_key") == "test_e2b_key" + assert manager.get_provider_config("modal").get("token_id") == "test_modal_id" + assert ( + manager.get_provider_config("modal").get("token_secret") + == "test_modal_secret" + ) + assert ( + manager.get_provider_config("daytona").get("api_key") == "test_daytona_key" + ) + + @pytest.mark.unit + def test_init_providers(self): + """Test provider initialization from config.""" + manager = ConfigManager() + manager._config = { + "providers": { + "provider1": {"key1": "value1"}, + "provider2": {"key2": "value2"}, + } + } + manager._init_providers() + + assert "provider1" in manager._providers + assert "provider2" in manager._providers + assert manager._providers["provider1"].get("key1") == "value1" + assert manager._providers["provider2"].get("key2") == "value2" + + @pytest.mark.unit + def test_default_config_paths(self): + """Test default configuration file paths.""" + expected_paths = [ + "grainchain.yaml", + "grainchain.yml", + ".grainchain.yaml", + ".grainchain.yml", + "~/.grainchain.yaml", + "~/.grainchain.yml", + ] + + assert ConfigManager.DEFAULT_CONFIG_PATHS == expected_paths + + @pytest.mark.unit + def test_load_config_from_default_paths(self, clean_env): + """Test loading config from default paths.""" + config_data = {"default_provider": "from_file"} + + # Create a config file in current directory + with open("grainchain.yaml", "w") as f: + yaml.dump(config_data, f) + + try: + manager = ConfigManager() + assert manager.default_provider == "from_file" + finally: + os.unlink("grainchain.yaml") + + +class TestGlobalConfigManager: + """Test cases for global config manager functions.""" + + @pytest.mark.unit + def test_get_config_manager_singleton(self): + """Test that get_config_manager returns singleton.""" + manager1 = get_config_manager() + manager2 = get_config_manager() + + assert manager1 is manager2 + + @pytest.mark.unit + def test_set_config_manager(self): + """Test setting custom config manager.""" + original_manager = 
get_config_manager() + custom_manager = ConfigManager() + + set_config_manager(custom_manager) + + assert get_config_manager() is custom_manager + + # Restore original for other tests + set_config_manager(original_manager) + + @pytest.mark.unit + def test_config_manager_reset(self): + """Test resetting config manager.""" + # Get initial manager + manager1 = get_config_manager() + + # Set a custom one + custom_manager = ConfigManager() + set_config_manager(custom_manager) + + # Reset to None and get new one + set_config_manager(None) + manager2 = get_config_manager() + + # Should be a new instance + assert manager2 is not manager1 + assert manager2 is not custom_manager + + +class TestConfigurationIntegration: + """Integration tests for configuration system.""" + + @pytest.mark.unit + def test_full_config_integration(self, temp_dir): + """Test full configuration integration.""" + # Create config file + config_file = temp_dir / "test_config.yaml" + config_data = { + "default_provider": "e2b", + "providers": { + "e2b": { + "api_key": "file_e2b_key", + "template": "python", + }, + "local": { + "working_dir": "/tmp", + }, + }, + "sandbox_defaults": { + "timeout": 900, + "memory_limit": "4GB", + "auto_cleanup": False, + }, + } + + with open(config_file, "w") as f: + yaml.dump(config_data, f) + + # Set environment variables (should override file values) + with patch.dict( + os.environ, + { + "E2B_API_KEY": "env_e2b_key", + "GRAINCHAIN_DEFAULT_PROVIDER": "local", + "E2B_TEMPLATE": "env_template", # Override file template + }, + ): + manager = ConfigManager(str(config_file)) + + # Environment should override file for default provider + assert manager.default_provider == "local" + + # Environment should override file for API key + e2b_config = manager.get_provider_config("e2b") + assert e2b_config.get("api_key") == "env_e2b_key" + assert e2b_config.get("template") == "env_template" # From environment + + # Sandbox defaults from file + sandbox_config = manager.get_sandbox_defaults() + assert sandbox_config.timeout == 900 + assert sandbox_config.memory_limit == "4GB" + assert sandbox_config.auto_cleanup is False + + @pytest.mark.unit + def test_config_precedence(self, temp_dir): + """Test configuration precedence (env > file > defaults).""" + # Create config file + config_file = temp_dir / "precedence_test.yaml" + config_data = { + "default_provider": "file_provider", + "providers": { + "test": { + "file_only": "file_value", + "both_file_env": "file_value", + }, + }, + } + + with open(config_file, "w") as f: + yaml.dump(config_data, f) + + # Set environment variables + with patch.dict( + os.environ, + { + "GRAINCHAIN_DEFAULT_PROVIDER": "env_provider", + "E2B_API_KEY": "env_e2b_key", + }, + ): + manager = ConfigManager(str(config_file)) + + # Environment wins over file + assert manager.default_provider == "env_provider" + + # File values are preserved when no env override + test_config = manager.get_provider_config("test") + assert test_config.get("file_only") == "file_value" + + # Environment values are added + e2b_config = manager.get_provider_config("e2b") + assert e2b_config.get("api_key") == "env_e2b_key" + + @pytest.mark.unit + def test_config_error_handling(self): + """Test configuration error handling.""" + # Test with directory instead of file + with pytest.raises(ConfigurationError): + ConfigManager("/tmp") # Directory, not file + + # Test with permission denied (simulate) + with patch("builtins.open", side_effect=PermissionError("Permission denied")): + with pytest.raises(ConfigurationError, 
match="Failed to load config file"): + ConfigManager("test.yaml") + + @pytest.mark.unit + def test_empty_config_file(self, temp_dir, clean_env): + """Test handling of empty configuration file.""" + config_file = temp_dir / "empty.yaml" + config_file.write_text("") + + manager = ConfigManager(str(config_file)) + + # Should use defaults + assert manager.default_provider == "e2b" + assert isinstance(manager.get_sandbox_defaults(), SandboxConfig) diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py new file mode 100644 index 0000000..3524e09 --- /dev/null +++ b/tests/unit/test_exceptions.py @@ -0,0 +1,382 @@ +"""Unit tests for exception classes.""" + +import pytest + +from grainchain.core.exceptions import ( + AuthenticationError, + ConfigurationError, + GrainchainError, + NetworkError, + ProviderError, + ResourceError, + SandboxError, +) +from grainchain.core.exceptions import ( + TimeoutError as GrainchainTimeoutError, +) + + +class TestGrainchainError: + """Test cases for base GrainchainError.""" + + @pytest.mark.unit + def test_grainchain_error_basic(self): + """Test basic GrainchainError functionality.""" + error = GrainchainError("Test error message") + assert str(error) == "Test error message" + assert isinstance(error, Exception) + + @pytest.mark.unit + def test_grainchain_error_inheritance(self): + """Test that all custom exceptions inherit from GrainchainError.""" + # Test exceptions with simple constructors + exceptions = [ + SandboxError, + ConfigurationError, + ResourceError, + NetworkError, + ] + + for exc_class in exceptions: + error = exc_class("Test message") + assert isinstance(error, GrainchainError) + assert isinstance(error, Exception) + assert str(error) == "Test message" + + # Test ProviderError separately due to different constructor + provider_error = ProviderError("Test message", "test_provider") + assert isinstance(provider_error, GrainchainError) + assert isinstance(provider_error, Exception) + assert str(provider_error) == "Test message" + assert provider_error.provider == "test_provider" + + # Test TimeoutError separately due to different constructor + timeout_error = GrainchainTimeoutError("Test timeout", 30) + assert isinstance(timeout_error, GrainchainError) + assert isinstance(timeout_error, Exception) + assert str(timeout_error) == "Test timeout" + assert timeout_error.timeout_seconds == 30 + + # Test AuthenticationError separately due to different constructor + auth_error = AuthenticationError("Test auth error", "test_provider") + assert isinstance(auth_error, GrainchainError) + assert isinstance(auth_error, Exception) + assert str(auth_error) == "Test auth error" + assert auth_error.provider == "test_provider" + + +class TestSandboxError: + """Test cases for SandboxError.""" + + @pytest.mark.unit + def test_sandbox_error_basic(self): + """Test basic SandboxError functionality.""" + error = SandboxError("Sandbox operation failed") + assert str(error) == "Sandbox operation failed" + assert isinstance(error, GrainchainError) + + @pytest.mark.unit + def test_sandbox_error_with_details(self): + """Test SandboxError with detailed message.""" + error = SandboxError("Command execution failed: exit code 1") + assert "Command execution failed" in str(error) + assert "exit code 1" in str(error) + + +class TestProviderError: + """Test cases for ProviderError.""" + + @pytest.mark.unit + def test_provider_error_basic(self): + """Test basic ProviderError functionality.""" + error = ProviderError("Provider failed", "test_provider") + assert str(error) 
== "Provider failed" + assert error.provider == "test_provider" + assert error.original_error is None + + @pytest.mark.unit + def test_provider_error_with_original(self): + """Test ProviderError with original exception.""" + original = ValueError("Original error") + error = ProviderError("Provider failed", "test_provider", original) + + assert str(error) == "Provider failed" + assert error.provider == "test_provider" + assert error.original_error is original + + @pytest.mark.unit + def test_provider_error_attributes(self): + """Test ProviderError attributes.""" + error = ProviderError("Test message", "e2b") + assert hasattr(error, "provider") + assert hasattr(error, "original_error") + assert error.provider == "e2b" + + +class TestConfigurationError: + """Test cases for ConfigurationError.""" + + @pytest.mark.unit + def test_configuration_error_basic(self): + """Test basic ConfigurationError functionality.""" + error = ConfigurationError("Invalid configuration") + assert str(error) == "Invalid configuration" + assert isinstance(error, GrainchainError) + + @pytest.mark.unit + def test_configuration_error_missing_key(self): + """Test ConfigurationError for missing configuration key.""" + error = ConfigurationError("Required configuration 'api_key' not found") + assert "api_key" in str(error) + assert "Required configuration" in str(error) + + +class TestTimeoutError: + """Test cases for TimeoutError.""" + + @pytest.mark.unit + def test_timeout_error_basic(self): + """Test basic TimeoutError functionality.""" + error = GrainchainTimeoutError("Operation timed out", 30) + assert str(error) == "Operation timed out" + assert error.timeout_seconds == 30 + + @pytest.mark.unit + def test_timeout_error_attributes(self): + """Test TimeoutError attributes.""" + error = GrainchainTimeoutError("Command timeout", 60) + assert hasattr(error, "timeout_seconds") + assert error.timeout_seconds == 60 + + @pytest.mark.unit + def test_timeout_error_different_timeouts(self): + """Test TimeoutError with different timeout values.""" + error1 = GrainchainTimeoutError("Short timeout", 5) + error2 = GrainchainTimeoutError("Long timeout", 300) + + assert error1.timeout_seconds == 5 + assert error2.timeout_seconds == 300 + + +class TestAuthenticationError: + """Test cases for AuthenticationError.""" + + @pytest.mark.unit + def test_authentication_error_basic(self): + """Test basic AuthenticationError functionality.""" + error = AuthenticationError("Authentication failed", "e2b") + assert str(error) == "Authentication failed" + assert error.provider == "e2b" + + @pytest.mark.unit + def test_authentication_error_attributes(self): + """Test AuthenticationError attributes.""" + error = AuthenticationError("Invalid API key", "modal") + assert hasattr(error, "provider") + assert error.provider == "modal" + + @pytest.mark.unit + def test_authentication_error_different_providers(self): + """Test AuthenticationError with different providers.""" + error1 = AuthenticationError("Auth failed", "e2b") + error2 = AuthenticationError("Auth failed", "daytona") + + assert error1.provider == "e2b" + assert error2.provider == "daytona" + + +class TestResourceError: + """Test cases for ResourceError.""" + + @pytest.mark.unit + def test_resource_error_basic(self): + """Test basic ResourceError functionality.""" + error = ResourceError("Resource allocation failed") + assert str(error) == "Resource allocation failed" + assert isinstance(error, GrainchainError) + + @pytest.mark.unit + def test_resource_error_memory_limit(self): + """Test 
ResourceError for memory limit.""" + error = ResourceError("Memory limit exceeded: 2GB") + assert "Memory limit exceeded" in str(error) + assert "2GB" in str(error) + + @pytest.mark.unit + def test_resource_error_cpu_limit(self): + """Test ResourceError for CPU limit.""" + error = ResourceError("CPU limit exceeded: 2.0 cores") + assert "CPU limit exceeded" in str(error) + assert "2.0 cores" in str(error) + + +class TestNetworkError: + """Test cases for NetworkError.""" + + @pytest.mark.unit + def test_network_error_basic(self): + """Test basic NetworkError functionality.""" + error = NetworkError("Network connection failed") + assert str(error) == "Network connection failed" + assert isinstance(error, GrainchainError) + + @pytest.mark.unit + def test_network_error_connection_timeout(self): + """Test NetworkError for connection timeout.""" + error = NetworkError("Connection timeout to provider API") + assert "Connection timeout" in str(error) + assert "provider API" in str(error) + + @pytest.mark.unit + def test_network_error_dns_resolution(self): + """Test NetworkError for DNS resolution.""" + error = NetworkError("DNS resolution failed for api.example.com") + assert "DNS resolution failed" in str(error) + assert "api.example.com" in str(error) + + +class TestExceptionChaining: + """Test exception chaining and context.""" + + @pytest.mark.unit + def test_exception_chaining_with_cause(self): + """Test exception chaining with __cause__.""" + original = ValueError("Original error") + + try: + raise original + except ValueError as e: + try: + raise SandboxError("Sandbox failed") from e + except SandboxError as chained: + assert chained.__cause__ is original + + @pytest.mark.unit + def test_exception_chaining_with_context(self): + """Test exception chaining with __context__.""" + try: + raise ValueError("Original error") + except ValueError: + try: + raise SandboxError("Sandbox failed") + except SandboxError as e: + assert isinstance(e.__context__, ValueError) + + @pytest.mark.unit + def test_provider_error_chaining(self): + """Test ProviderError with original exception.""" + original = ConnectionError("Network unreachable") + provider_error = ProviderError("Provider connection failed", "e2b", original) + + assert provider_error.original_error is original + assert str(provider_error) == "Provider connection failed" + + +class TestExceptionMessages: + """Test exception message formatting.""" + + @pytest.mark.unit + def test_detailed_error_messages(self): + """Test that error messages contain useful information.""" + # Test with detailed context + error = SandboxError( + "Command 'python script.py' failed with exit code 1 in sandbox 'test_123'" + ) + message = str(error) + + assert "python script.py" in message + assert "exit code 1" in message + assert "test_123" in message + + @pytest.mark.unit + def test_configuration_error_messages(self): + """Test ConfigurationError message formatting.""" + error = ConfigurationError( + "Required configuration 'api_key' not found for provider 'e2b'" + ) + message = str(error) + + assert "api_key" in message + assert "e2b" in message + assert "Required configuration" in message + + @pytest.mark.unit + def test_timeout_error_messages(self): + """Test TimeoutError message formatting.""" + error = GrainchainTimeoutError( + "Command execution timed out after 30 seconds", 30 + ) + message = str(error) + + assert "timed out" in message + assert "30 seconds" in message + assert error.timeout_seconds == 30 + + +class TestExceptionUsagePatterns: + """Test common 
exception usage patterns.""" + + @pytest.mark.unit + def test_exception_in_try_except(self): + """Test exception handling in try-except blocks.""" + + def failing_function(): + raise SandboxError("Test failure") + + with pytest.raises(SandboxError, match="Test failure"): + failing_function() + + @pytest.mark.unit + def test_exception_inheritance_catching(self): + """Test catching exceptions by base class.""" + + def failing_function(): + raise ProviderError("Provider failed", "test") + + # Should be caught by GrainchainError + with pytest.raises(GrainchainError): + failing_function() + + @pytest.mark.unit + def test_multiple_exception_types(self): + """Test handling multiple exception types.""" + + def maybe_failing_function(error_type): + if error_type == "sandbox": + raise SandboxError("Sandbox error") + elif error_type == "provider": + raise ProviderError("Provider error", "test") + elif error_type == "config": + raise ConfigurationError("Config error") + else: + return "success" + + # Test each exception type + with pytest.raises(SandboxError): + maybe_failing_function("sandbox") + + with pytest.raises(ProviderError): + maybe_failing_function("provider") + + with pytest.raises(ConfigurationError): + maybe_failing_function("config") + + # Test success case + result = maybe_failing_function("none") + assert result == "success" + + @pytest.mark.unit + def test_exception_with_additional_context(self): + """Test exceptions with additional context information.""" + context = { + "sandbox_id": "test_123", + "command": "python script.py", + "working_dir": "/tmp", + "timeout": 30, + } + + error_msg = f"Command '{context['command']}' failed in sandbox '{context['sandbox_id']}'" + error = SandboxError(error_msg) + + assert context["command"] in str(error) + assert context["sandbox_id"] in str(error) diff --git a/tests/unit/test_interfaces.py b/tests/unit/test_interfaces.py new file mode 100644 index 0000000..a869ee2 --- /dev/null +++ b/tests/unit/test_interfaces.py @@ -0,0 +1,529 @@ +"""Unit tests for core interfaces and data structures.""" + +from datetime import datetime + +import pytest + +from grainchain.core.interfaces import ( + ExecutionResult, + FileInfo, + SandboxConfig, + SandboxStatus, +) + + +class TestExecutionResult: + """Test cases for ExecutionResult.""" + + @pytest.mark.unit + def test_execution_result_creation(self): + """Test basic ExecutionResult creation.""" + result = ExecutionResult( + command="echo 'hello'", + return_code=0, + stdout="hello\n", + stderr="", + execution_time=0.1, + success=True, + ) + + assert result.command == "echo 'hello'" + assert result.return_code == 0 + assert result.stdout == "hello\n" + assert result.stderr == "" + assert result.execution_time == 0.1 + assert result.success is True + + @pytest.mark.unit + def test_execution_result_with_error(self): + """Test ExecutionResult with error output.""" + result = ExecutionResult( + command="exit 1", + return_code=1, + stdout="", + stderr="Command failed\n", + execution_time=0.05, + success=False, + ) + + assert result.command == "exit 1" + assert result.return_code == 1 + assert result.stdout == "" + assert result.stderr == "Command failed\n" + assert result.execution_time == 0.05 + + @pytest.mark.unit + def test_execution_result_success_property(self): + """Test ExecutionResult success property.""" + # Successful command + success_result = ExecutionResult( + command="echo 'test'", + return_code=0, + stdout="test\n", + stderr="", + execution_time=0.1, + success=True, + ) + assert success_result.success is 
True
+
+        # Failed command
+        failed_result = ExecutionResult(
+            command="exit 1",
+            return_code=1,
+            stdout="",
+            stderr="error\n",
+            execution_time=0.1,
+            success=False,
+        )
+        assert failed_result.success is False
+
+    @pytest.mark.unit
+    def test_execution_result_with_long_output(self):
+        """Test ExecutionResult with long output."""
+        long_output = "x" * 10000
+        result = ExecutionResult(
+            command="python -c 'print(\"x\" * 10000)'",
+            return_code=0,
+            stdout=long_output,
+            stderr="",
+            execution_time=0.5,
+            success=True,
+        )
+
+        assert len(result.stdout) == 10000
+        assert result.stdout == long_output
+
+    @pytest.mark.unit
+    def test_execution_result_with_unicode(self):
+        """Test ExecutionResult with Unicode characters."""
+        unicode_output = "Hello 世界 🌍"
+        result = ExecutionResult(
+            command="echo 'Hello 世界 🌍'",
+            return_code=0,
+            stdout=unicode_output,
+            stderr="",
+            execution_time=0.1,
+            success=True,
+        )
+
+        assert result.stdout == unicode_output
+
+    @pytest.mark.unit
+    def test_execution_result_repr(self):
+        """Test ExecutionResult string representation."""
+        result = ExecutionResult(
+            command="echo 'test'",
+            return_code=0,
+            stdout="test\n",
+            stderr="",
+            execution_time=0.1,
+            success=True,
+        )
+
+        repr_str = repr(result)
+        assert "ExecutionResult" in repr_str
+        assert "return_code=0" in repr_str
+
+    @pytest.mark.unit
+    def test_execution_result_equality(self):
+        """Test ExecutionResult equality comparison."""
+        import time
+
+        timestamp = time.time()
+
+        result1 = ExecutionResult(
+            command="echo 'test'",
+            return_code=0,
+            stdout="test\n",
+            stderr="",
+            execution_time=0.1,
+            success=True,
+            timestamp=timestamp,
+        )
+
+        result2 = ExecutionResult(
+            command="echo 'test'",
+            return_code=0,
+            stdout="test\n",
+            stderr="",
+            execution_time=0.1,
+            success=True,
+            timestamp=timestamp,
+        )
+
+        result3 = ExecutionResult(
+            command="echo 'different'",
+            return_code=0,
+            stdout="different\n",
+            stderr="",
+            execution_time=0.1,
+            success=True,
+            timestamp=timestamp,
+        )
+
+        assert result1 == result2
+        assert result1 != result3
+
+
+class TestFileInfo:
+    """Test cases for FileInfo."""
+
+    @pytest.mark.unit
+    def test_file_info_basic(self):
+        """Test basic FileInfo creation."""
+        file_info = FileInfo(
+            name="test.txt",
+            path="/tmp/test.txt",
+            size=1024,
+            is_directory=False,
+            modified_time=None,
+        )
+
+        assert file_info.name == "test.txt"
+        assert file_info.path == "/tmp/test.txt"
+        assert file_info.size == 1024
+        assert file_info.is_directory is False
+        assert file_info.modified_time is None
+
+    @pytest.mark.unit
+    def test_file_info_directory(self):
+        """Test FileInfo for directory."""
+        dir_info = FileInfo(
+            name="mydir",
+            path="/tmp/mydir",
+            size=4096,
+            is_directory=True,
+            modified_time=datetime.now(),
+        )
+
+        assert dir_info.name == "mydir"
+        assert dir_info.path == "/tmp/mydir"
+        assert dir_info.size == 4096
+        assert dir_info.is_directory is True
+        assert isinstance(dir_info.modified_time, datetime)
+
+    @pytest.mark.unit
+    def test_file_info_with_timestamp(self):
+        """Test FileInfo with modification timestamp."""
+        timestamp = datetime(2023, 1, 1, 12, 0, 0)
+        file_info = FileInfo(
+            name="timestamped.txt",
+            path="/tmp/timestamped.txt",
+            size=512,
+            is_directory=False,
+            modified_time=timestamp,
+        )
+
+        assert file_info.modified_time == timestamp
+
+    @pytest.mark.unit
+    def test_file_info_large_file(self):
+        """Test FileInfo for large file."""
+        large_size = 1024 * 1024 * 1024  # 1GB
+        file_info = FileInfo(
+            name="large_file.bin",
+            path="/tmp/large_file.bin",
+
size=large_size, + is_directory=False, + modified_time=None, + ) + + assert file_info.size == large_size + + @pytest.mark.unit + def test_file_info_repr(self): + """Test FileInfo string representation.""" + file_info = FileInfo( + name="test.txt", + path="/tmp/test.txt", + size=1024, + is_directory=False, + modified_time=None, + ) + + repr_str = repr(file_info) + assert "FileInfo" in repr_str + assert "test.txt" in repr_str + + @pytest.mark.unit + def test_file_info_equality(self): + """Test FileInfo equality comparison.""" + file1 = FileInfo( + name="test.txt", + path="/tmp/test.txt", + size=1024, + is_directory=False, + modified_time=None, + ) + + file2 = FileInfo( + name="test.txt", + path="/tmp/test.txt", + size=1024, + is_directory=False, + modified_time=None, + ) + + file3 = FileInfo( + name="different.txt", + path="/tmp/different.txt", + size=1024, + is_directory=False, + modified_time=None, + ) + + assert file1 == file2 + assert file1 != file3 + + +class TestSandboxStatus: + """Test cases for SandboxStatus enum.""" + + @pytest.mark.unit + def test_sandbox_status_values(self): + """Test SandboxStatus enum values.""" + assert SandboxStatus.CREATING.value == "creating" + assert SandboxStatus.RUNNING.value == "running" + assert SandboxStatus.STOPPED.value == "stopped" + assert SandboxStatus.ERROR.value == "error" + assert SandboxStatus.UNKNOWN.value == "unknown" + + @pytest.mark.unit + def test_sandbox_status_comparison(self): + """Test SandboxStatus comparison.""" + assert SandboxStatus.CREATING == SandboxStatus.CREATING + assert SandboxStatus.RUNNING != SandboxStatus.STOPPED + assert SandboxStatus.ERROR != SandboxStatus.UNKNOWN + + @pytest.mark.unit + def test_sandbox_status_string_representation(self): + """Test SandboxStatus string representation.""" + assert str(SandboxStatus.RUNNING) == "SandboxStatus.RUNNING" + assert SandboxStatus.RUNNING.value == "running" + + @pytest.mark.unit + def test_sandbox_status_iteration(self): + """Test iterating over SandboxStatus values.""" + statuses = list(SandboxStatus) + expected = [ + SandboxStatus.CREATING, + SandboxStatus.RUNNING, + SandboxStatus.STOPPED, + SandboxStatus.ERROR, + SandboxStatus.UNKNOWN, + ] + + assert len(statuses) == len(expected) + for status in expected: + assert status in statuses + + +class TestSandboxConfig: + """Test cases for SandboxConfig (already tested in test_config.py, but interface-specific tests here).""" + + @pytest.mark.unit + def test_sandbox_config_as_interface(self): + """Test SandboxConfig as part of the interface.""" + config = SandboxConfig( + timeout=300, + memory_limit="2GB", + cpu_limit=2.0, + working_directory="/workspace", + environment_vars={"TEST": "value"}, + auto_cleanup=True, + ) + + # Test that it can be used as expected in interfaces + assert isinstance(config.timeout, int) + assert isinstance(config.memory_limit, str) + assert isinstance(config.cpu_limit, float) + assert isinstance(config.working_directory, str) + assert isinstance(config.environment_vars, dict) + assert isinstance(config.auto_cleanup, bool) + + @pytest.mark.unit + def test_sandbox_config_optional_fields(self): + """Test SandboxConfig with optional fields.""" + config = SandboxConfig() + + # Optional fields should be None or have defaults + assert config.memory_limit is None + assert config.cpu_limit is None + assert config.image is None + assert config.timeout == 300 # Default + assert config.working_directory == "~" # Default + assert config.environment_vars == {} # Default + assert config.auto_cleanup is True # Default + 
+ +class TestInterfaceIntegration: + """Test integration between different interface components.""" + + @pytest.mark.unit + def test_execution_result_in_context(self): + """Test ExecutionResult in realistic context.""" + # Simulate a command execution scenario + command = "python -c 'import sys; print(sys.version)'" + + # Successful execution + success_result = ExecutionResult( + command=command, + return_code=0, + stdout="Python 3.9.0\n", + stderr="", + execution_time=0.25, + success=True, + ) + + assert success_result.success + assert "Python" in success_result.stdout + assert success_result.execution_time > 0 + + @pytest.mark.unit + def test_file_info_listing_scenario(self): + """Test FileInfo in file listing scenario.""" + # Simulate a directory listing + files = [ + FileInfo( + name="script.py", + path="/workspace/script.py", + size=1024, + is_directory=False, + modified_time=datetime.now(), + ), + FileInfo( + name="data", + path="/workspace/data", + size=4096, + is_directory=True, + modified_time=datetime.now(), + ), + FileInfo( + name="output.txt", + path="/workspace/output.txt", + size=512, + is_directory=False, + modified_time=datetime.now(), + ), + ] + + # Test filtering + python_files = [f for f in files if f.name.endswith(".py")] + directories = [f for f in files if f.is_directory] + + assert len(python_files) == 1 + assert len(directories) == 1 + assert python_files[0].name == "script.py" + assert directories[0].name == "data" + + @pytest.mark.unit + def test_sandbox_status_workflow(self): + """Test SandboxStatus in workflow scenario.""" + # Simulate sandbox lifecycle + statuses = [ + SandboxStatus.CREATING, + SandboxStatus.RUNNING, + SandboxStatus.STOPPED, + ] + + # Test status transitions + assert statuses[0] == SandboxStatus.CREATING + assert statuses[1] == SandboxStatus.RUNNING + assert statuses[2] == SandboxStatus.STOPPED + + # Test that we can check for specific states + current_status = SandboxStatus.RUNNING + assert current_status in [SandboxStatus.RUNNING, SandboxStatus.CREATING] + assert current_status not in [SandboxStatus.STOPPED, SandboxStatus.ERROR] + + @pytest.mark.unit + def test_config_with_execution_context(self): + """Test SandboxConfig in execution context.""" + config = SandboxConfig( + timeout=60, + working_directory="/tmp", + environment_vars={"PYTHONPATH": "/custom/path"}, + ) + + # Simulate using config for command execution + + # The config should provide the context for execution + assert config.timeout == 60 + assert config.working_directory == "/tmp" + assert "PYTHONPATH" in config.environment_vars + assert config.environment_vars["PYTHONPATH"] == "/custom/path" + + +class TestInterfaceValidation: + """Test interface validation and edge cases.""" + + @pytest.mark.unit + def test_execution_result_edge_cases(self): + """Test ExecutionResult edge cases.""" + # Empty command + result = ExecutionResult( + command="", + return_code=0, + stdout="", + stderr="", + execution_time=0.0, + success=True, + ) + assert result.command == "" + assert result.execution_time == 0.0 + + # Negative return code + result = ExecutionResult( + command="test", + return_code=-1, + stdout="", + stderr="Signal terminated", + execution_time=0.1, + success=False, + ) + assert result.return_code == -1 + assert not result.success + + @pytest.mark.unit + def test_file_info_edge_cases(self): + """Test FileInfo edge cases.""" + # Zero-size file + file_info = FileInfo( + name="empty.txt", + path="/tmp/empty.txt", + size=0, + is_directory=False, + modified_time=None, + ) + assert 
file_info.size == 0 + + # File with special characters in name + file_info = FileInfo( + name="file with spaces & symbols!.txt", + path="/tmp/file with spaces & symbols!.txt", + size=100, + is_directory=False, + modified_time=None, + ) + assert " " in file_info.name + assert "&" in file_info.name + + @pytest.mark.unit + def test_sandbox_config_edge_cases(self): + """Test SandboxConfig edge cases.""" + # Very short timeout + config = SandboxConfig(timeout=1) + assert config.timeout == 1 + + # Very long timeout + config = SandboxConfig(timeout=86400) # 24 hours + assert config.timeout == 86400 + + # Empty environment variables + config = SandboxConfig(environment_vars={}) + assert config.environment_vars == {} + + # Many environment variables + many_vars = {f"VAR_{i}": f"value_{i}" for i in range(100)} + config = SandboxConfig(environment_vars=many_vars) + assert len(config.environment_vars) == 100 diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py new file mode 100644 index 0000000..745de32 --- /dev/null +++ b/tests/unit/test_providers.py @@ -0,0 +1,590 @@ +"""Unit tests for provider implementations.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from grainchain.core.config import ProviderConfig +from grainchain.core.exceptions import ( + ConfigurationError, + ProviderError, +) +from grainchain.core.interfaces import SandboxConfig, SandboxStatus +from grainchain.providers.base import BaseSandboxProvider, BaseSandboxSession +from tests.conftest import MockSandboxSession + +# Check for optional dependencies at module level +try: + import e2b # noqa: F401 + + E2B_AVAILABLE = True +except ImportError: + E2B_AVAILABLE = False + +try: + import modal # noqa: F401 + + MODAL_AVAILABLE = True +except ImportError: + MODAL_AVAILABLE = False + +try: + import daytona_sdk # noqa: F401 + + DAYTONA_AVAILABLE = True +except ImportError: + DAYTONA_AVAILABLE = False + + +class TestBaseSandboxProvider: + """Test cases for the base sandbox provider.""" + + @pytest.mark.unit + def test_provider_init(self, provider_config): + """Test provider initialization.""" + + class TestProvider(BaseSandboxProvider): + @property + def name(self) -> str: + return "test" + + async def _create_session(self, config: SandboxConfig): + return MagicMock() + + provider = TestProvider(provider_config) + assert provider.config == provider_config + assert provider.name == "test" + assert not provider._closed + + @pytest.mark.unit + async def test_create_sandbox_success(self, mock_provider, test_config): + """Test successful sandbox creation.""" + session = await mock_provider.create_sandbox(test_config) + + assert session is not None + assert session.sandbox_id in mock_provider._sessions + assert len(mock_provider.created_sessions) == 1 + + @pytest.mark.unit + async def test_create_sandbox_after_close(self, mock_provider, test_config): + """Test sandbox creation after provider is closed.""" + await mock_provider.cleanup() + + with pytest.raises(ProviderError, match="Provider has been closed"): + await mock_provider.create_sandbox(test_config) + + @pytest.mark.unit + async def test_list_sandboxes(self, mock_provider, test_config): + """Test listing sandboxes.""" + # Create some sessions + session1 = await mock_provider.create_sandbox(test_config) + session2 = await mock_provider.create_sandbox(test_config) + + # List sandboxes + sandboxes = await mock_provider.list_sandboxes() + assert len(sandboxes) == 2 + assert session1.sandbox_id in sandboxes + assert 
session2.sandbox_id in sandboxes + + @pytest.mark.unit + async def test_get_sandbox_status(self, mock_provider, test_config): + """Test getting sandbox status.""" + # Unknown sandbox + status = await mock_provider.get_sandbox_status("unknown_id") + assert status == SandboxStatus.UNKNOWN + + # Known sandbox + session = await mock_provider.create_sandbox(test_config) + status = await mock_provider.get_sandbox_status(session.sandbox_id) + assert status == SandboxStatus.RUNNING + + @pytest.mark.unit + async def test_provider_cleanup(self, mock_provider, test_config): + """Test provider cleanup.""" + # Create some sessions + await mock_provider.create_sandbox(test_config) + await mock_provider.create_sandbox(test_config) + + # Cleanup + await mock_provider.cleanup() + + # Verify cleanup was called + assert mock_provider.cleanup_called + + @pytest.mark.unit + def test_get_config_value(self, mock_provider): + """Test getting configuration values.""" + value = mock_provider.get_config_value("api_key") + assert value == "test_key" + + default_value = mock_provider.get_config_value("nonexistent", "default") + assert default_value == "default" + + @pytest.mark.unit + def test_require_config_value(self, mock_provider): + """Test requiring configuration values.""" + value = mock_provider.require_config_value("api_key") + assert value == "test_key" + + with pytest.raises(ConfigurationError, match="Required configuration"): + mock_provider.require_config_value("nonexistent") + + @pytest.mark.unit + async def test_session_removal_on_close(self, mock_provider, test_config): + """Test that sessions are removed from provider when closed.""" + session = await mock_provider.create_sandbox(test_config) + sandbox_id = session.sandbox_id + + assert sandbox_id in mock_provider._sessions + + await session.close() + + assert sandbox_id not in mock_provider._sessions + + +class TestBaseSandboxSession: + """Test cases for the base sandbox session.""" + + @pytest.mark.unit + def test_session_init(self, mock_provider, test_config): + """Test session initialization.""" + # Use MockSandboxSession instead of abstract BaseSandboxSession + session = MockSandboxSession("test_id", mock_provider, test_config) + + assert session.sandbox_id == "test_id" + assert session.config == test_config + assert ( + session.status == SandboxStatus.RUNNING + ) # MockSandboxSession sets this to RUNNING + assert not session._closed + + @pytest.mark.unit + async def test_session_close(self, mock_session): + """Test session closing.""" + sandbox_id = mock_session.sandbox_id + provider = mock_session._provider + + await mock_session.close() + + assert mock_session._closed + assert mock_session.status == SandboxStatus.STOPPED + assert sandbox_id not in provider._sessions + + @pytest.mark.unit + async def test_session_double_close(self, mock_session): + """Test that double closing doesn't cause issues.""" + await mock_session.close() + await mock_session.close() # Should not raise + + assert mock_session._closed + + @pytest.mark.unit + async def test_ensure_not_closed(self, mock_session): + """Test operations on closed session.""" + await mock_session.close() + + with pytest.raises(ProviderError, match="is closed"): + mock_session._ensure_not_closed() + + @pytest.mark.unit + async def test_default_snapshot_not_implemented(self, mock_session): + """Test that default snapshot methods raise NotImplementedError.""" + # Create a minimal session that doesn't override snapshot methods + minimal_session = MinimalSandboxSession( + "test_id", 
mock_session._provider, mock_session._config + ) + + with pytest.raises(NotImplementedError, match="Snapshots not supported"): + await minimal_session.create_snapshot() + + with pytest.raises(NotImplementedError, match="Snapshots not supported"): + await minimal_session.restore_snapshot("test_snapshot") + + +class TestE2BProvider: + """Test cases for E2B provider.""" + + @pytest.mark.unit + @pytest.mark.skipif( + not E2B_AVAILABLE, + reason="E2B package not available", + ) + @patch("grainchain.providers.e2b.E2BSandbox") + def test_e2b_provider_init(self, mock_e2b, provider_config): + """Test E2B provider initialization.""" + from grainchain.providers.e2b import E2BProvider + + provider = E2BProvider(provider_config) + assert provider.name == "e2b" + + @pytest.mark.unit + def test_e2b_provider_missing_package(self): + """Test E2B provider with missing package.""" + # This test should run when e2b is not installed + try: + import e2b # noqa: F401 + + pytest.skip("E2B package is installed, skipping missing package test") + except ImportError: + pass + + from grainchain.core.config import ProviderConfig + from grainchain.providers.e2b import E2BProvider + + config = ProviderConfig("e2b", {"api_key": "test_key"}) + + with pytest.raises( + ImportError, match="E2B provider requires the 'e2b' package" + ): + E2BProvider(config) + + @pytest.mark.unit + @pytest.mark.skipif( + not E2B_AVAILABLE, + reason="E2B package not available", + ) + @patch("grainchain.providers.e2b.E2BSandbox") + def test_e2b_provider_missing_api_key(self, mock_e2b): + """Test E2B provider with missing API key.""" + from grainchain.providers.e2b import E2BProvider + + config = ProviderConfig("e2b", {}) + + with pytest.raises(ConfigurationError, match="E2B API key is required"): + E2BProvider(config) + + @pytest.mark.unit + @pytest.mark.skipif( + not E2B_AVAILABLE, + reason="E2B package not available", + ) + @patch("grainchain.providers.e2b.E2BSandbox") + async def test_e2b_session_creation(self, mock_e2b, provider_config, test_config): + """Test E2B session creation.""" + from grainchain.providers.e2b import E2BProvider + + # Mock E2B sandbox + mock_sandbox = AsyncMock() + mock_sandbox.id = "e2b_test_id" + mock_e2b.Sandbox.create.return_value = mock_sandbox + + provider = E2BProvider(provider_config) + session = await provider._create_session(test_config) + + assert session.sandbox_id == "e2b_test_id" + mock_e2b.Sandbox.create.assert_called_once() + + +class TestModalProvider: + """Test cases for Modal provider.""" + + @pytest.mark.unit + @pytest.mark.skipif( + not MODAL_AVAILABLE, + reason="Modal package not available", + ) + @patch("grainchain.providers.modal.modal") + def test_modal_provider_init(self, mock_modal, provider_config): + """Test Modal provider initialization.""" + from grainchain.providers.modal import ModalProvider + + provider = ModalProvider(provider_config) + assert provider.name == "modal" + + @pytest.mark.unit + def test_modal_provider_missing_package(self): + """Test Modal provider with missing package.""" + # This test should run when modal is not installed + try: + import modal # noqa: F401 + + pytest.skip("Modal package is installed, skipping missing package test") + except ImportError: + pass + + from grainchain.core.config import ProviderConfig + from grainchain.providers.modal import ModalProvider + + config = ProviderConfig("modal", {"token": "test_token"}) + + with pytest.raises( + ImportError, match="Modal provider requires the 'modal' package" + ): + ModalProvider(config) + + @pytest.mark.unit + 
@pytest.mark.skipif( + not MODAL_AVAILABLE, + reason="Modal package not available", + ) + @patch("grainchain.providers.modal.modal") + def test_modal_provider_missing_credentials(self, mock_modal): + """Test Modal provider with missing credentials.""" + from grainchain.providers.modal import ModalProvider + + config = ProviderConfig("modal", {}) + + with pytest.raises(ConfigurationError, match="Modal credentials are required"): + ModalProvider(config) + + +class TestDaytonaProvider: + """Test cases for Daytona provider.""" + + @pytest.mark.unit + @pytest.mark.skipif( + not DAYTONA_AVAILABLE, + reason="Daytona SDK package not available", + ) + @patch("grainchain.providers.daytona.Daytona") + def test_daytona_provider_init(self, mock_daytona, provider_config): + """Test Daytona provider initialization.""" + from grainchain.providers.daytona import DaytonaProvider + + provider = DaytonaProvider(provider_config) + assert provider.name == "daytona" + + @pytest.mark.unit + def test_daytona_provider_missing_package(self): + """Test Daytona provider with missing package.""" + # This test should run when daytona_sdk is not installed + try: + import daytona_sdk # noqa: F401 + + pytest.skip( + "Daytona SDK package is installed, skipping missing package test" + ) + except ImportError: + pass + + from grainchain.core.config import ProviderConfig + from grainchain.providers.daytona import DaytonaProvider + + config = ProviderConfig("daytona", {"api_key": "test_key"}) + + with pytest.raises( + ImportError, match="Daytona provider requires the 'daytona-sdk' package" + ): + DaytonaProvider(config) + + @pytest.mark.unit + @pytest.mark.skipif( + not DAYTONA_AVAILABLE, + reason="Daytona SDK package not available", + ) + @patch("grainchain.providers.daytona.Daytona") + def test_daytona_provider_missing_api_key(self, mock_daytona): + """Test Daytona provider with missing API key.""" + from grainchain.providers.daytona import DaytonaProvider + + config = ProviderConfig("daytona", {}) + + with pytest.raises(ConfigurationError, match="Daytona API key is required"): + DaytonaProvider(config) + + +class TestLocalProvider: + """Test cases for Local provider.""" + + @pytest.mark.unit + def test_local_provider_init(self, provider_config): + """Test Local provider initialization.""" + from grainchain.providers.local import LocalProvider + + provider = LocalProvider(provider_config) + assert provider.name == "local" + + @pytest.mark.unit + async def test_local_session_creation(self, provider_config, test_config): + """Test Local session creation.""" + from grainchain.providers.local import LocalProvider + + provider = LocalProvider(provider_config) + session = await provider._create_session(test_config) + + assert session.sandbox_id.startswith("local_") + assert session.status == SandboxStatus.RUNNING + + await session.close() + + @pytest.mark.unit + async def test_local_command_execution(self, provider_config, test_config): + """Test Local provider command execution.""" + from grainchain.providers.local import LocalProvider + + provider = LocalProvider(provider_config) + session = await provider._create_session(test_config) + + try: + result = await session.execute("echo 'test'") + assert result.return_code == 0 + assert "test" in result.stdout + finally: + await session.close() + + @pytest.mark.unit + async def test_local_file_operations(self, provider_config, test_config, temp_dir): + """Test Local provider file operations.""" + from grainchain.providers.local import LocalProvider + + # Override working directory to 
use temp dir + test_config.working_directory = str(temp_dir) + + provider = LocalProvider(provider_config) + session = await provider._create_session(test_config) + + try: + # Upload file + await session.upload_file("test.txt", "test content") + + # Download file + content = await session.download_file("test.txt") + assert content == b"test content" + + # List files + files = await session.list_files(".") + file_names = [f.name for f in files] + assert "test.txt" in file_names + finally: + await session.close() + + +class TestProviderErrorHandling: + """Test error handling across providers.""" + + @pytest.mark.unit + async def test_provider_error_propagation(self, failing_provider, test_config): + """Test that provider errors are properly propagated.""" + with pytest.raises(ProviderError, match="Failed to create sandbox"): + await failing_provider.create_sandbox(test_config) + + @pytest.mark.unit + async def test_timeout_error_handling(self, timeout_provider, test_config): + """Test timeout error handling.""" + with pytest.raises(TimeoutError): + # Use a very short timeout to trigger the error + await asyncio.wait_for( + timeout_provider.create_sandbox(test_config), timeout=0.1 + ) + + @pytest.mark.unit + @pytest.mark.skipif( + not E2B_AVAILABLE, + reason="E2B package not available", + ) + @patch("grainchain.providers.e2b.E2BSandbox") + async def test_authentication_error(self, mock_e2b, provider_config, test_config): + """Test authentication error handling.""" + from grainchain.providers.e2b import E2BProvider + + # Mock authentication failure + mock_e2b.Sandbox.create.side_effect = Exception("Authentication failed") + + provider = E2BProvider(provider_config) + + with pytest.raises(ProviderError, match="Failed to create sandbox"): + await provider._create_session(test_config) + + @pytest.mark.unit + async def test_network_error_handling(self, mock_provider, test_config): + """Test network error handling.""" + + class NetworkErrorProvider(BaseSandboxProvider): + @property + def name(self) -> str: + return "network_error" + + async def _create_session(self, config: SandboxConfig): + raise ConnectionError("Network unreachable") + + provider = NetworkErrorProvider(ProviderConfig("test", {})) + + with pytest.raises(ProviderError, match="Failed to create sandbox"): + await provider.create_sandbox(test_config) + + @pytest.mark.unit + async def test_resource_exhaustion_error(self, mock_provider, test_config): + """Test resource exhaustion error handling.""" + + class ResourceErrorProvider(BaseSandboxProvider): + @property + def name(self) -> str: + return "resource_error" + + async def _create_session(self, config: SandboxConfig): + raise Exception("Resource limit exceeded") + + provider = ResourceErrorProvider(ProviderConfig("test", {})) + + with pytest.raises(ProviderError, match="Failed to create sandbox"): + await provider.create_sandbox(test_config) + + +class TestProviderConfiguration: + """Test provider configuration handling.""" + + @pytest.mark.unit + def test_provider_config_validation(self): + """Test provider configuration validation.""" + config = ProviderConfig("test", {"key": "value"}) + + assert config.name == "test" + assert config.get("key") == "value" + assert config.get("missing", "default") == "default" + + @pytest.mark.unit + def test_provider_config_modification(self): + """Test provider configuration modification.""" + config = ProviderConfig("test", {}) + + config.set("new_key", "new_value") + assert config.get("new_key") == "new_value" + + @pytest.mark.unit + def 
test_sandbox_config_defaults(self): + """Test sandbox configuration defaults.""" + config = SandboxConfig() + + assert config.timeout == 300 + assert config.working_directory == "~" + assert config.auto_cleanup is True + assert config.keep_alive is False + + @pytest.mark.unit + def test_sandbox_config_customization(self): + """Test sandbox configuration customization.""" + config = SandboxConfig( + timeout=600, + memory_limit="4GB", + cpu_limit=2.0, + working_directory="/custom", + environment_vars={"CUSTOM": "value"}, + auto_cleanup=False, + ) + + assert config.timeout == 600 + assert config.memory_limit == "4GB" + assert config.cpu_limit == 2.0 + assert config.working_directory == "/custom" + assert config.environment_vars["CUSTOM"] == "value" + assert config.auto_cleanup is False + + +class MinimalSandboxSession(BaseSandboxSession): + """Minimal session implementation for testing base class behavior.""" + + async def execute(self, command: str, timeout: int = None) -> dict: + return {"stdout": "", "stderr": "", "exit_code": 0} + + async def upload_file(self, local_path: str, remote_path: str) -> None: + pass + + async def download_file(self, remote_path: str, local_path: str) -> None: + pass + + async def list_files(self, path: str = "/") -> list: + return [] + + async def _cleanup(self) -> None: + pass diff --git a/tests/unit/test_sandbox.py b/tests/unit/test_sandbox.py new file mode 100644 index 0000000..d94a4b0 --- /dev/null +++ b/tests/unit/test_sandbox.py @@ -0,0 +1,345 @@ +"""Unit tests for the core Sandbox class.""" + +from unittest.mock import patch + +import pytest + +from grainchain import Sandbox +from grainchain.core.exceptions import ( + ConfigurationError, + SandboxError, +) +from grainchain.core.interfaces import ExecutionResult, SandboxStatus + + +class TestSandbox: + """Test cases for the Sandbox class.""" + + @pytest.mark.unit + def test_sandbox_init_with_provider_string(self, config_manager): + """Test sandbox initialization with provider string.""" + with patch( + "grainchain.core.sandbox.get_config_manager", return_value=config_manager + ): + sandbox = Sandbox(provider="local") + assert sandbox.provider_name == "local" + + @pytest.mark.unit + def test_sandbox_init_with_provider_instance(self, mock_provider): + """Test sandbox initialization with provider instance.""" + sandbox = Sandbox(provider=mock_provider) + assert sandbox.provider_name == "mock" + + @pytest.mark.unit + def test_sandbox_init_with_config(self, mock_provider, test_config): + """Test sandbox initialization with custom config.""" + sandbox = Sandbox(provider=mock_provider, config=test_config) + assert sandbox._config == test_config + + @pytest.mark.unit + def test_sandbox_init_invalid_provider_type(self): + """Test sandbox initialization with invalid provider type.""" + with pytest.raises(ConfigurationError, match="Invalid provider type"): + Sandbox(provider=123) + + @pytest.mark.unit + async def test_sandbox_context_manager(self, mock_provider, test_config): + """Test sandbox as async context manager.""" + sandbox = Sandbox(provider=mock_provider, config=test_config) + + async with sandbox as ctx: + assert ctx is sandbox + assert sandbox.sandbox_id is not None + assert sandbox.status == SandboxStatus.RUNNING + + @pytest.mark.unit + async def test_sandbox_explicit_create_close(self, mock_provider, test_config): + """Test explicit sandbox creation and closing.""" + sandbox = Sandbox(provider=mock_provider, config=test_config) + + # Create sandbox + await sandbox.create() + assert sandbox.sandbox_id 
is not None + assert sandbox.status == SandboxStatus.RUNNING + + # Close sandbox + await sandbox.close() + assert sandbox._closed + + @pytest.mark.unit + async def test_sandbox_double_create_error(self, mock_provider, test_config): + """Test that creating sandbox twice raises error.""" + sandbox = Sandbox(provider=mock_provider, config=test_config) + + await sandbox.create() + with pytest.raises(SandboxError, match="Sandbox session already exists"): + await sandbox.create() + + @pytest.mark.unit + async def test_sandbox_reuse_closed_error(self, mock_provider, test_config): + """Test that reusing closed sandbox raises error.""" + sandbox = Sandbox(provider=mock_provider, config=test_config) + + async with sandbox: + pass # Sandbox is closed after context + + with pytest.raises(SandboxError, match="Cannot reuse a closed sandbox"): + async with sandbox: + pass + + @pytest.mark.unit + async def test_execute_command(self, mock_sandbox): + """Test command execution.""" + result = await mock_sandbox.execute("echo 'Hello, World!'") + + assert isinstance(result, ExecutionResult) + assert result.command == "echo 'Hello, World!'" + assert result.return_code == 0 + assert "Hello, World!" in result.stdout + + @pytest.mark.unit + async def test_execute_command_with_options(self, mock_sandbox): + """Test command execution with additional options.""" + result = await mock_sandbox.execute( + "python -c 'print(\"test\")'", + timeout=60, + working_dir="/tmp", + environment={"TEST_VAR": "test_value"}, + ) + + assert result.return_code == 0 + assert "Python output" in result.stdout + + @pytest.mark.unit + async def test_execute_command_failure(self, mock_sandbox): + """Test command execution failure.""" + result = await mock_sandbox.execute("exit 1") + + assert result.return_code == 1 + assert "Command failed" in result.stderr + + @pytest.mark.unit + async def test_execute_without_session(self, mock_provider, test_config): + """Test command execution without active session.""" + sandbox = Sandbox(provider=mock_provider, config=test_config) + + with pytest.raises(SandboxError, match="Sandbox session not initialized"): + await sandbox.execute("echo test") + + @pytest.mark.unit + async def test_upload_file_string(self, mock_sandbox): + """Test file upload with string content.""" + content = "Hello, World!" 
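+        # Upload a text payload; the mock session is expected to record it as UTF-8 bytes (checked below)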
+        await mock_sandbox.upload_file("/test/hello.txt", content)
+
+        # Verify file was uploaded (check mock session)
+        session = mock_sandbox._session
+        assert "/test/hello.txt" in session.uploaded_files
+        assert session.uploaded_files["/test/hello.txt"]["content"] == content.encode()
+
+    @pytest.mark.unit
+    async def test_upload_file_bytes(self, mock_sandbox):
+        """Test file upload with bytes content."""
+        content = b"Binary content"
+        await mock_sandbox.upload_file("/test/binary.bin", content, mode="wb")
+
+        session = mock_sandbox._session
+        assert "/test/binary.bin" in session.uploaded_files
+        assert session.uploaded_files["/test/binary.bin"]["content"] == content
+
+    @pytest.mark.unit
+    async def test_download_file(self, mock_sandbox):
+        """Test file download."""
+        # First upload a file
+        content = "Test file content"
+        await mock_sandbox.upload_file("/test/download.txt", content)
+
+        # Then download it
+        downloaded = await mock_sandbox.download_file("/test/download.txt")
+        assert downloaded == content.encode()
+
+    @pytest.mark.unit
+    async def test_download_nonexistent_file(self, mock_sandbox):
+        """Test downloading non-existent file."""
+        with pytest.raises(SandboxError, match="File download failed"):
+            await mock_sandbox.download_file("/nonexistent/file.txt")
+
+    @pytest.mark.unit
+    async def test_list_files(self, mock_sandbox):
+        """Test file listing."""
+        # Upload some files
+        await mock_sandbox.upload_file("/test/file1.txt", "content1")
+        await mock_sandbox.upload_file("/test/file2.txt", "content2")
+
+        files = await mock_sandbox.list_files("/test")
+
+        assert len(files) >= 2  # At least our uploaded files
+        file_names = [f.name for f in files]
+        assert "file1.txt" in file_names
+        assert "file2.txt" in file_names
+
+    @pytest.mark.unit
+    async def test_create_snapshot(self, mock_sandbox):
+        """Test snapshot creation."""
+        # Upload a file first
+        await mock_sandbox.upload_file("/test/snapshot_test.txt", "snapshot content")
+
+        snapshot_id = await mock_sandbox.create_snapshot()
+
+        assert isinstance(snapshot_id, str)
+        assert snapshot_id.startswith("snapshot_")
+
+    @pytest.mark.unit
+    async def test_restore_snapshot(self, mock_sandbox):
+        """Test snapshot restoration."""
+        # Upload initial file
+        await mock_sandbox.upload_file("/test/initial.txt", "initial content")
+
+        # Create snapshot
+        snapshot_id = await mock_sandbox.create_snapshot()
+
+        # Upload another file
+        await mock_sandbox.upload_file("/test/after_snapshot.txt", "after snapshot")
+
+        # Restore snapshot
+        await mock_sandbox.restore_snapshot(snapshot_id)
+
+        # Verify restoration (the second file should be gone)
+        session = mock_sandbox._session
+        assert "/test/initial.txt" in session.uploaded_files
+        assert "/test/after_snapshot.txt" not in session.uploaded_files
+
+    @pytest.mark.unit
+    async def test_restore_invalid_snapshot(self, mock_sandbox):
+        """Test restoring invalid snapshot."""
+        with pytest.raises(SandboxError, match="Snapshot restoration failed"):
+            await mock_sandbox.restore_snapshot("invalid_snapshot_id")
+
+    @pytest.mark.unit
+    def test_sandbox_properties(self, mock_sandbox):
+        """Test sandbox properties."""
+        assert mock_sandbox.provider_name == "mock"
+        assert mock_sandbox.status == SandboxStatus.RUNNING
+        assert mock_sandbox.sandbox_id is not None
+
+    @pytest.mark.unit
+    def test_sandbox_repr(self, mock_sandbox):
+        """Test sandbox string representation."""
+        repr_str = repr(mock_sandbox)
+        assert "Sandbox" in repr_str
+        assert "provider=mock" in repr_str
+        assert "status=running" in repr_str
+
+    @pytest.mark.unit
+    async def test_provider_creation_error(self, config_manager):
+        """Test error handling during provider creation."""
+        with patch(
+            "grainchain.core.sandbox.get_config_manager", return_value=config_manager
+        ):
+            with pytest.raises(ConfigurationError, match="Unknown provider"):
+                Sandbox(provider="unknown_provider")
+
+    @pytest.mark.unit
+    async def test_session_creation_error(self, failing_provider, test_config):
+        """Test error handling during session creation."""
+        sandbox = Sandbox(provider=failing_provider, config=test_config)
+
+        with pytest.raises(SandboxError, match="Failed to create sandbox"):
+            async with sandbox:
+                pass
+
+    @pytest.mark.unit
+    async def test_operation_without_session_error(self, mock_provider, test_config):
+        """Test operations without active session raise appropriate errors."""
+        sandbox = Sandbox(provider=mock_provider, config=test_config)
+
+        operations = [
+            lambda: sandbox.execute("echo test"),
+            lambda: sandbox.upload_file("/test.txt", "content"),
+            lambda: sandbox.download_file("/test.txt"),
+            lambda: sandbox.list_files("/"),
+            lambda: sandbox.create_snapshot(),
+            lambda: sandbox.restore_snapshot("test"),
+        ]
+
+        for operation in operations:
+            with pytest.raises(SandboxError, match="Sandbox session not initialized"):
+                await operation()
+
+    @pytest.mark.unit
+    async def test_timeout_handling(self, mock_sandbox):
+        """Test timeout handling in command execution."""
+        with pytest.raises(SandboxError, match="Command execution failed"):
+            await mock_sandbox.execute("timeout_command")
+
+    @pytest.mark.unit
+    async def test_config_override_timeout(self, mock_sandbox):
+        """Test that timeout parameter overrides config default."""
+        # The mock session records the timeout parameter
+        await mock_sandbox.execute("echo test", timeout=120)
+
+        session = mock_sandbox._session
+        last_command = session.executed_commands[-1]
+        assert last_command["timeout"] == 120
+
+    @pytest.mark.unit
+    async def test_environment_variables(self, mock_sandbox):
+        """Test environment variable passing."""
+        env_vars = {"CUSTOM_VAR": "custom_value", "ANOTHER_VAR": "another_value"}
+        await mock_sandbox.execute("echo $CUSTOM_VAR", environment=env_vars)
+
+        session = mock_sandbox._session
+        last_command = session.executed_commands[-1]
+        assert last_command["environment"] == env_vars
+
+    @pytest.mark.unit
+    async def test_working_directory(self, mock_sandbox):
+        """Test working directory specification."""
+        await mock_sandbox.execute("pwd", working_dir="/custom/dir")
+
+        session = mock_sandbox._session
+        last_command = session.executed_commands[-1]
+        assert last_command["working_dir"] == "/custom/dir"
+
+
+class TestSandboxIntegration:
+    """Integration tests for Sandbox with different providers."""
+
+    @pytest.mark.unit
+    async def test_multiple_sandboxes_same_provider(self, mock_provider, test_config):
+        """Test creating multiple sandboxes with the same provider."""
+        sandbox1 = Sandbox(provider=mock_provider, config=test_config)
+        sandbox2 = Sandbox(provider=mock_provider, config=test_config)
+
+        async with sandbox1:
+            async with sandbox2:
+                assert sandbox1.sandbox_id != sandbox2.sandbox_id
+                assert len(mock_provider.created_sessions) == 2
+
+    @pytest.mark.unit
+    async def test_sandbox_cleanup_on_exception(self, mock_provider, test_config):
+        """Test that sandbox is properly cleaned up on exception."""
+        sandbox = Sandbox(provider=mock_provider, config=test_config)
+
+        try:
+            async with sandbox:
+                raise ValueError("Test exception")
+        except ValueError:
+            pass
+
+        assert sandbox._closed
+
+    @pytest.mark.unit
+    async def test_provider_cleanup(self, mock_provider, test_config):
+        """Test provider cleanup functionality."""
+        sandbox1 = Sandbox(provider=mock_provider, config=test_config)
+        sandbox2 = Sandbox(provider=mock_provider, config=test_config)
+
+        async with sandbox1:
+            async with sandbox2:
+                pass
+
+        # Clean up provider
+        await mock_provider.cleanup()
+        assert mock_provider._closed
+        assert len(mock_provider._sessions) == 0