From 6557758070a6c6ee56989d8abaa47ef36912ec6a Mon Sep 17 00:00:00 2001 From: DmitriyAlergant-T1A Date: Mon, 22 Sep 2025 12:57:05 -0400 Subject: [PATCH] Added LocalStorageClient and /storage/file API route --- backend/chainlit/data/__init__.py | 12 +- backend/chainlit/data/storage_clients/base.py | 8 + .../chainlit/data/storage_clients/local.py | 192 +++++++ backend/chainlit/server.py | 92 +++ backend/tests/conftest.py | 11 +- .../tests/data/storage_clients/test_local.py | 536 ++++++++++++++++++ 6 files changed, 849 insertions(+), 2 deletions(-) create mode 100644 backend/chainlit/data/storage_clients/local.py create mode 100644 backend/tests/data/storage_clients/test_local.py diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index 636362597b..a7652ca4d4 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -57,9 +57,13 @@ def get_data_layer(): azure_storage_key = os.getenv("APP_AZURE_STORAGE_ACCESS_KEY") is_using_azure = bool(azure_storage_account and azure_storage_key) + # Local Storage + local_storage_path = os.getenv("APP_LOCAL_STORAGE_PATH") + is_using_local = bool(local_storage_path) + storage_client = None - if sum([is_using_s3, is_using_gcs, is_using_azure]) > 1: + if sum([is_using_s3, is_using_gcs, is_using_azure, is_using_local]) > 1: warnings.warn( "Multiple storage configurations detected. Please use only one." ) @@ -92,6 +96,12 @@ def get_data_layer(): storage_account=azure_storage_account, storage_key=azure_storage_key, ) + elif is_using_local: + from chainlit.data.storage_clients.local import LocalStorageClient + + storage_client = LocalStorageClient( + storage_path=local_storage_path, + ) _data_layer = ChainlitDataLayer( database_url=database_url, storage_client=storage_client diff --git a/backend/chainlit/data/storage_clients/base.py b/backend/chainlit/data/storage_clients/base.py index 9bf458be35..cc03a69a51 100644 --- a/backend/chainlit/data/storage_clients/base.py +++ b/backend/chainlit/data/storage_clients/base.py @@ -26,3 +26,11 @@ async def delete_file(self, object_key: str) -> bool: @abstractmethod async def get_read_url(self, object_key: str) -> str: pass + + async def download_file(self, object_key: str) -> tuple[bytes, str] | None: + """ + Optional method to download file content directly, to allow file downloads to be proxied by the Chainlit backend itself. + + Returns (file_content, mime_type) if implemented, None otherwise.
+ """ + return None diff --git a/backend/chainlit/data/storage_clients/local.py b/backend/chainlit/data/storage_clients/local.py new file mode 100644 index 0000000000..1b1cd45bd9 --- /dev/null +++ b/backend/chainlit/data/storage_clients/local.py @@ -0,0 +1,192 @@ +import mimetypes +import shutil +from pathlib import Path +from typing import Any, Dict, Union +from urllib.request import pathname2url + +from chainlit import make_async +from chainlit.data.storage_clients.base import BaseStorageClient +from chainlit.logger import logger + + +class LocalStorageClient(BaseStorageClient): + """ + Class to enable local file system storage provider + """ + + def __init__(self, storage_path: str): + try: + self.storage_path = Path(storage_path).resolve() + + # Create storage directory if it doesn't exist + self.storage_path.mkdir(parents=True, exist_ok=True) + + logger.info( + f"LocalStorageClient initialized with path: {self.storage_path}" + ) + except Exception as e: + logger.warning(f"LocalStorageClient initialization error: {e}") + raise + + def _validate_object_key(self, object_key: str) -> Path: + """ + Validate object_key and ensure the resolved path is within storage directory. + + Args: + object_key: The object key to validate + + Returns: + Resolved Path object within storage directory + + Raises: + ValueError: If path traversal is detected or path is invalid + """ + try: + # Reject absolute paths immediately + if object_key.startswith("/"): + logger.warning(f"Absolute path rejected: {object_key}") + raise ValueError("Invalid object key: absolute paths not allowed") + + # Normalize object_key and check for traversal patterns + normalized_key = object_key.strip() + if ".." in normalized_key or "\\" in normalized_key: + logger.warning(f"Path traversal patterns detected: {object_key}") + raise ValueError("Invalid object key: path traversal detected") + + # Create the file path + file_path = self.storage_path / normalized_key + resolved_path = file_path.resolve() + + # Ensure the resolved path is within the storage directory + resolved_path.relative_to(self.storage_path) + + return resolved_path + except ValueError as e: + # Re-raise ValueError as is (our custom errors) + raise e + except Exception as e: + logger.warning(f"Path validation error for {object_key}: {e}") + raise ValueError(f"Invalid object key: {e}") + + def sync_get_read_url(self, object_key: str) -> str: + try: + file_path = self._validate_object_key(object_key) + if file_path.exists(): + # Return URL pointing to the backend's storage route + url_path = pathname2url(object_key) + return f"/storage/file/{url_path}" + else: + logger.warning(f"LocalStorageClient: File not found: {object_key}") + return object_key + except ValueError: + # Path validation failed, return object_key as fallback + return object_key + except Exception as e: + logger.warning(f"LocalStorageClient, get_read_url error: {e}") + return object_key + + async def get_read_url(self, object_key: str) -> str: + return await make_async(self.sync_get_read_url)(object_key) + + def sync_upload_file( + self, + object_key: str, + data: Union[bytes, str], + mime: str = "application/octet-stream", + overwrite: bool = True, + content_disposition: str | None = None, + ) -> Dict[str, Any]: + try: + file_path = self._validate_object_key(object_key) + + # Create parent directories if they don't exist + file_path.parent.mkdir(parents=True, exist_ok=True) + + # Check if file exists and overwrite is False + if file_path.exists() and not overwrite: + logger.warning( + 
f"LocalStorageClient: File exists and overwrite=False: {object_key}" + ) + return {} + + # Write data to file + if isinstance(data, str): + file_path.write_text(data, encoding="utf-8") + else: + file_path.write_bytes(data) + + # Generate URL for the uploaded file using backend's storage route + relative_path = file_path.relative_to(self.storage_path) + url_path = pathname2url(str(relative_path)) + url = f"/storage/file/{url_path}" + + return {"object_key": object_key, "url": url} + except ValueError as e: + logger.warning(f"LocalStorageClient, upload_file error: {e}") + return {} + except Exception as e: + logger.warning(f"LocalStorageClient, upload_file error: {e}") + return {} + + async def upload_file( + self, + object_key: str, + data: Union[bytes, str], + mime: str = "application/octet-stream", + overwrite: bool = True, + content_disposition: str | None = None, + ) -> Dict[str, Any]: + return await make_async(self.sync_upload_file)( + object_key, data, mime, overwrite, content_disposition + ) + + def sync_delete_file(self, object_key: str) -> bool: + try: + file_path = self._validate_object_key(object_key) + if file_path.exists(): + if file_path.is_file(): + file_path.unlink() + elif file_path.is_dir(): + shutil.rmtree(file_path) + return True + else: + logger.warning( + f"LocalStorageClient: File not found for deletion: {object_key}" + ) + return False + except ValueError as e: + logger.warning(f"LocalStorageClient, delete_file error: {e}") + return False + except Exception as e: + logger.warning(f"LocalStorageClient, delete_file error: {e}") + return False + + async def delete_file(self, object_key: str) -> bool: + return await make_async(self.sync_delete_file)(object_key) + + def sync_download_file(self, object_key: str) -> tuple[bytes, str] | None: + try: + file_path = self._validate_object_key(object_key) + if not file_path.exists() or not file_path.is_file(): + logger.warning( + f"LocalStorageClient: File not found for download: {object_key}" + ) + return None + + # Get MIME type + mime_type, _ = mimetypes.guess_type(str(file_path)) + if not mime_type: + mime_type = "application/octet-stream" + + # Read file content + content = file_path.read_bytes() + return (content, mime_type) + except ValueError as e: + logger.warning(f"LocalStorageClient, download_file error: {e}") + return None + except Exception as e: + logger.warning(f"LocalStorageClient, download_file error: {e}") + return None + + async def download_file(self, object_key: str) -> tuple[bytes, str] | None: + return await make_async(self.sync_download_file)(object_key) diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 800326165a..b13f8af318 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -1636,6 +1636,98 @@ async def get_file( raise HTTPException(status_code=404, detail="File not found") +@router.get("/storage/file/{object_key:path}") +async def get_storage_file( + object_key: str, + current_user: UserParam, +): + """Get a file from the storage client if it supports direct downloads.""" + from chainlit.data import get_data_layer + + data_layer = get_data_layer() + if not data_layer or not data_layer.storage_client: + raise HTTPException( + status_code=404, + detail="Storage not configured", + ) + + # Validate user authentication + if not current_user: + raise HTTPException(status_code=401, detail="Unauthorized") + + # Extract thread_id from object_key to validate thread ownership + # Object key patterns: + # 1. 
threads/{thread_id}/files/{element.id} (chainlit_data_layer) + # 2. {user_id}/{thread_id}/{element.id} (dynamodb) + # 3. {user_id}/{element.id}[/{element.name}] (sql_alchemy) + thread_id = None + + # Try to extract thread_id from different patterns + parts = object_key.split("/") + if len(parts) >= 3: + if parts[0] == "threads": + # Pattern: threads/{thread_id}/files/{element.id} + thread_id = parts[1] + elif len(parts) == 3: + # Pattern: {user_id}/{thread_id}/{element.id} (dynamodb) + # We need to verify this is actually a thread_id by checking if it exists + potential_thread_id = parts[1] + try: + # Check if this looks like a thread by validating thread author + await is_thread_author(current_user.identifier, potential_thread_id) + thread_id = potential_thread_id + except HTTPException: + # Not a valid thread or user doesn't have access + pass + + # If we found a thread_id, validate thread ownership + if thread_id: + await is_thread_author(current_user.identifier, thread_id) + else: + # For files without thread association (pattern 3), we should still + # validate that the user_id in the path matches the current user + if len(parts) >= 2: + user_id_in_path = parts[0] + if user_id_in_path != current_user.identifier: + raise HTTPException( + status_code=403, + detail="Access denied: file belongs to different user", + ) + + # Try to extract element_id and get the original filename from database + element_id = None + element_name = None + + # Extract element_id from object_key patterns + if len(parts) >= 4 and parts[0] == "threads" and parts[2] == "files": + # Pattern: threads/{thread_id}/files/{element_id} + element_id = parts[3] + # Query database for element details + if thread_id and element_id: + element = await data_layer.get_element(thread_id, element_id) + if element: + element_name = element.get("name") + + # Only serve files if storage client implements download_file + file_data = await data_layer.storage_client.download_file(object_key) + if file_data is None: + raise HTTPException( + status_code=404, + detail="File not found or storage client does not support direct downloads", + ) + + content, mime_type = file_data + + # Use the original filename if available, otherwise fall back to the UUID + filename = element_name if element_name else Path(object_key).name + + return Response( + content=content, + media_type=mime_type, + headers={"Content-Disposition": f"inline; filename={filename}"}, + ) + + @router.get("/favicon") async def get_favicon(): """Get the favicon for the UI.""" diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 91e1bc74c0..063aefb001 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -7,6 +7,7 @@ import pytest import pytest_asyncio +import chainlit.data as data_module from chainlit import config from chainlit.callbacks import data_layer from chainlit.context import ChainlitContext, context_var @@ -94,10 +95,18 @@ def mock_data_layer(monkeypatch: pytest.MonkeyPatch) -> AsyncMock: @pytest.fixture -def mock_get_data_layer(mock_data_layer: AsyncMock, test_config: config.ChainlitConfig): +def mock_get_data_layer( + mock_data_layer: AsyncMock, + test_config: config.ChainlitConfig, + monkeypatch: pytest.MonkeyPatch, +): # Instantiate mock data layer mock_get_data_layer = Mock(return_value=mock_data_layer) + # Clear the cached data layer so every test exercises its own factory. 
+ monkeypatch.setattr(data_module, "_data_layer", None) + monkeypatch.setattr(data_module, "_data_layer_initialized", False) + # Configure it using @data_layer decorator return data_layer(mock_get_data_layer) diff --git a/backend/tests/data/storage_clients/test_local.py b/backend/tests/data/storage_clients/test_local.py new file mode 100644 index 0000000000..8c911613b7 --- /dev/null +++ b/backend/tests/data/storage_clients/test_local.py @@ -0,0 +1,536 @@ +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from fastapi import HTTPException +from fastapi.testclient import TestClient + +from chainlit.auth import get_current_user +from chainlit.data.storage_clients.local import LocalStorageClient +from chainlit.server import app + + +class TestLocalStorageClient: + @pytest.fixture + def temp_storage_dir(self): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmp_dir: + yield Path(tmp_dir) + + @pytest.fixture + def local_client(self, temp_storage_dir): + """Create a LocalStorageClient instance with temporary storage.""" + return LocalStorageClient(storage_path=str(temp_storage_dir)) + + def test_init(self, temp_storage_dir): + """Test LocalStorageClient initialization.""" + client = LocalStorageClient(storage_path=str(temp_storage_dir)) + assert client.storage_path == temp_storage_dir.resolve() + assert client.storage_path.exists() + + @pytest.mark.asyncio + async def test_upload_file_text(self, local_client, temp_storage_dir): + """Test uploading a text file.""" + content = "Hello, World!" + result = await local_client.upload_file("test.txt", content, "text/plain") + + # Check return value + assert result["object_key"] == "test.txt" + assert "url" in result + + # Check file was actually written + file_path = temp_storage_dir / "test.txt" + assert file_path.exists() + assert file_path.read_text() == content + + @pytest.mark.asyncio + async def test_upload_file_bytes(self, local_client, temp_storage_dir): + """Test uploading binary data.""" + content = b"Binary content here" + result = await local_client.upload_file( + "test.bin", content, "application/octet-stream" + ) + + # Check return value + assert result["object_key"] == "test.bin" + + # Check file was actually written + file_path = temp_storage_dir / "test.bin" + assert file_path.exists() + assert file_path.read_bytes() == content + + @pytest.mark.asyncio + async def test_upload_file_overwrite_false(self, local_client, temp_storage_dir): + """Test upload with overwrite=False when file exists.""" + # First upload + await local_client.upload_file("test.txt", "Original content") + + # Second upload with overwrite=False + result = await local_client.upload_file( + "test.txt", "New content", overwrite=False + ) + + # Should return empty dict and not overwrite + assert result == {} + + file_path = temp_storage_dir / "test.txt" + assert file_path.read_text() == "Original content" + + @pytest.mark.asyncio + async def test_get_read_url(self, local_client, temp_storage_dir): + """Test getting read URL for existing file.""" + # Upload a file first + await local_client.upload_file("test.txt", "content") + + url = await local_client.get_read_url("test.txt") + assert url == "/storage/file/test.txt" + + @pytest.mark.asyncio + async def test_get_read_url_special_characters( + self, local_client, temp_storage_dir + ): + """Test getting read URL for file with special characters.""" + # Upload a file with special characters + object_key = "folder with 
spaces/file with spaces.txt" + await local_client.upload_file(object_key, "content") + + url = await local_client.get_read_url(object_key) + # URL should be properly encoded + assert "folder%20with%20spaces/file%20with%20spaces.txt" in url + + @pytest.mark.asyncio + async def test_url_consistency_upload_and_read( + self, local_client, temp_storage_dir + ): + """Test that URL format is consistent between upload_file and get_read_url.""" + object_key = "test/consistency.txt" + content = "test content" + + # Upload file and get the URL from upload response + upload_result = await local_client.upload_file(object_key, content) + upload_url = upload_result["url"] + + # Get read URL using get_read_url method + read_url = await local_client.get_read_url(object_key) + + # Both URLs should use the same format (backend's storage route) + assert upload_url == read_url, ( + f"URL inconsistency: upload={upload_url}, read={read_url}" + ) + assert upload_url.startswith("/storage/file/"), ( + f"Upload URL should use storage route: {upload_url}" + ) + assert read_url.startswith("/storage/file/"), ( + f"Read URL should use storage route: {read_url}" + ) + + # Verify URL contains the expected object key + assert "test/consistency.txt" in upload_url + assert "test/consistency.txt" in read_url + + @pytest.mark.asyncio + async def test_download_file(self, local_client, temp_storage_dir): + """Test downloading file content.""" + content = "File content for download" + await local_client.upload_file("download_test.txt", content) + + result = await local_client.download_file("download_test.txt") + assert result is not None + + file_content, mime_type = result + assert file_content == content.encode() + assert mime_type == "text/plain" + + @pytest.mark.asyncio + async def test_delete_file(self, local_client, temp_storage_dir): + """Test deleting a file.""" + # Upload a file first + await local_client.upload_file("to_delete.txt", "delete me") + file_path = temp_storage_dir / "to_delete.txt" + assert file_path.exists() + + # Delete the file + result = await local_client.delete_file("to_delete.txt") + assert result is True + assert not file_path.exists() + + # Security Tests + @pytest.mark.asyncio + async def test_path_traversal_attacks(self, local_client, temp_storage_dir): + """Test that path traversal attempts are blocked.""" + # Create a file outside storage directory to attempt to access + outside_file = temp_storage_dir.parent / "secret.txt" + outside_file.write_text("secret content") + + # Test various path traversal attempts + path_traversal_attempts = [ + "../secret.txt", + "../../secret.txt", + "../../../etc/passwd", + "/etc/passwd", + "threads/../../../secret.txt", + "valid/../../../secret.txt", + "..\\secret.txt", # Windows style + "threads/..\\..\\secret.txt", + ] + + for malicious_path in path_traversal_attempts: + # Upload should fail + result = await local_client.upload_file(malicious_path, "attack") + assert result == {}, f"Upload should fail for {malicious_path}" + + # Download should fail + result = await local_client.download_file(malicious_path) + assert result is None, f"Download should fail for {malicious_path}" + + # Delete should fail + result = await local_client.delete_file(malicious_path) + assert result is False, f"Delete should fail for {malicious_path}" + + # get_read_url should return fallback + result = await local_client.get_read_url(malicious_path) + assert result == malicious_path, ( + f"get_read_url should return fallback for {malicious_path}" + ) + + @pytest.mark.asyncio + async 
def test_path_validation_edge_cases(self, local_client): + """Test edge cases in path validation.""" + edge_cases = [ + "", # empty string + "/", # root + "//", # double slash + "./file.txt", # current directory + "file.txt/../other.txt", # traversal in middle + "null\x00byte.txt", # null byte injection + "very/deep/nested/../../../attack.txt", # deep nested traversal + ] + + for edge_case in edge_cases: + result = await local_client.upload_file(edge_case, "content") + # Most should fail, but some like "./file.txt" might be normalized + if edge_case in ["", "/", "//", "null\x00byte.txt"]: + assert result == {}, f"Should reject {edge_case}" + + @pytest.mark.asyncio + async def test_safe_paths_still_work(self, local_client): + """Test that legitimate paths still work after security fixes.""" + safe_paths = [ + "file.txt", + "folder/file.txt", + "deep/nested/folder/file.txt", + "threads/123/files/456.txt", + "user_id/element_id.txt", + "user_id/thread_id/element_id", + "file with spaces.txt", + "file-with-dashes_and_underscores.txt", + ] + + for safe_path in safe_paths: + # Upload should succeed + result = await local_client.upload_file(safe_path, "content") + assert result.get("object_key") == safe_path + + # Download should work + result = await local_client.download_file(safe_path) + assert result is not None + content, _mime_type = result + assert content == b"content" + + +class TestLocalStorageAPIIntegration: + """End-to-end integration tests with the FastAPI server.""" + + @pytest.fixture + def temp_storage_dir(self): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmp_dir: + yield Path(tmp_dir) + + @pytest.fixture + def mock_data_layer(self, temp_storage_dir): + """Mock the data layer with local storage client.""" + client = LocalStorageClient(storage_path=str(temp_storage_dir)) + + data_layer = Mock() + data_layer.storage_client = client + data_layer.get_element = AsyncMock(return_value=None) + + with patch("chainlit.data.get_data_layer", return_value=data_layer): + yield data_layer + + @pytest.fixture + def test_client(self): + """Create test client for FastAPI app.""" + return TestClient(app) + + @pytest.fixture + def mock_user(self): + """Mock authenticated user.""" + from unittest.mock import Mock + + user = Mock() + user.id = "test_user" + user.identifier = "test_user" + return user + + def test_get_storage_file_success( + self, test_client, mock_data_layer, temp_storage_dir, mock_user + ): + """Test successful file retrieval via API.""" + # Upload a file using the storage client directly + content = "Test file content" + storage_client = mock_data_layer.storage_client + + # Use sync method for direct upload in test + storage_client.sync_upload_file("test.txt", content, "text/plain") + + # Mock authentication + # Mock the dependency directly in the app + def mock_get_current_user(): + return mock_user + + app.dependency_overrides[get_current_user] = mock_get_current_user + try: + response = test_client.get("/storage/file/test.txt") + finally: + app.dependency_overrides.clear() + + assert response.status_code == 200 + assert response.text == content + assert response.headers["content-type"] == "text/plain; charset=utf-8" + + def test_get_storage_file_not_found(self, test_client, mock_data_layer, mock_user): + """Test file retrieval when file doesn't exist.""" + + def mock_get_current_user(): + return mock_user + + app.dependency_overrides[get_current_user] = mock_get_current_user + try: + response = 
test_client.get("/storage/file/nonexistent.txt") + assert response.status_code == 404 + assert "File not found" in response.json()["detail"] + finally: + app.dependency_overrides.clear() + + def test_get_storage_file_no_storage_configured(self, test_client, mock_user): + """Test API behavior when no storage is configured.""" + + def mock_get_current_user(): + return mock_user + + def mock_get_data_layer(): + return None + + with patch("chainlit.data.get_data_layer", side_effect=mock_get_data_layer): + app.dependency_overrides[get_current_user] = mock_get_current_user + try: + response = test_client.get("/storage/file/test.txt") + finally: + app.dependency_overrides.clear() + + assert response.status_code == 404 + assert "Storage not configured" in response.json()["detail"] + + def test_get_storage_file_storage_no_download_support(self, test_client, mock_user): + """Test API behavior when storage doesn't support direct downloads.""" + # Mock a storage client that doesn't implement download_file + mock_storage_client = Mock() + mock_storage_client.download_file = AsyncMock(return_value=None) + + mock_data_layer = Mock() + mock_data_layer.storage_client = mock_storage_client + + def mock_get_current_user(): + return mock_user + + with patch("chainlit.data.get_data_layer", return_value=mock_data_layer): + app.dependency_overrides[get_current_user] = mock_get_current_user + try: + response = test_client.get("/storage/file/test.txt") + finally: + app.dependency_overrides.clear() + + assert response.status_code == 404 + assert "does not support direct downloads" in response.json()["detail"] + + def test_get_storage_file_path_traversal_blocked( + self, test_client, mock_data_layer, mock_user + ): + """Test that path traversal attempts are blocked at API level.""" + path_traversal_attempts = [ + "../../../etc/passwd", + "/etc/passwd", + "threads/../../../secret.txt", + ] + + def mock_get_current_user(): + return mock_user + + app.dependency_overrides[get_current_user] = mock_get_current_user + try: + for malicious_path in path_traversal_attempts: + # URL encode the malicious path + from urllib.parse import quote + + encoded_path = quote(malicious_path, safe="") + + response = test_client.get(f"/storage/file/{encoded_path}") + + # Should fail - either 400 (bad request), 404 (file not found due to path validation) + # or 403 (access denied) + assert response.status_code in [400, 403, 404], ( + f"Expected 400, 403 or 404 for {malicious_path}, got {response.status_code}" + ) + finally: + app.dependency_overrides.clear() + + def test_get_storage_file_thread_authorization( + self, test_client, mock_data_layer, temp_storage_dir + ): + """Test thread authorization in storage file endpoint.""" + storage_client = mock_data_layer.storage_client + + # Create a file with thread structure + object_key = "threads/thread123/files/element456.txt" + storage_client.sync_upload_file(object_key, "thread content", "text/plain") + + # Mock different users + from unittest.mock import Mock + + authorized_user = Mock() + authorized_user.id = "user1" + authorized_user.identifier = "user1" + + unauthorized_user = Mock() + unauthorized_user.id = "user2" + unauthorized_user.identifier = "user2" + + # Mock is_thread_author to allow only user1 access to thread123 + async def mock_is_thread_author(user_id, thread_id): + if user_id == "user1" and thread_id == "thread123": + return True + raise HTTPException(status_code=403, detail="Access denied") + + with patch( + "chainlit.server.is_thread_author", 
side_effect=mock_is_thread_author + ): + # Authorized user should succeed + def mock_get_authorized_user(): + return authorized_user + + app.dependency_overrides[get_current_user] = mock_get_authorized_user + try: + response = test_client.get( + "/storage/file/threads/thread123/files/element456.txt" + ) + assert response.status_code == 200 + assert response.text == "thread content" + finally: + app.dependency_overrides.clear() + + # Unauthorized user should be denied + def mock_get_unauthorized_user(): + return unauthorized_user + + app.dependency_overrides[get_current_user] = mock_get_unauthorized_user + try: + response = test_client.get( + "/storage/file/threads/thread123/files/element456.txt" + ) + assert response.status_code == 403 + finally: + app.dependency_overrides.clear() + + def test_get_storage_file_user_file_authorization( + self, test_client, mock_data_layer, temp_storage_dir + ): + """Test user file authorization for non-thread files.""" + storage_client = mock_data_layer.storage_client + + # Create a file with user structure (sql_alchemy pattern) + object_key = "user1/element123.txt" + storage_client.sync_upload_file(object_key, "user file content", "text/plain") + + # Mock users + from unittest.mock import Mock + + correct_user = Mock() + correct_user.id = "user1" + correct_user.identifier = "user1" + + wrong_user = Mock() + wrong_user.id = "user2" + wrong_user.identifier = "user2" + + # Correct user should succeed + def mock_get_correct_user(): + return correct_user + + app.dependency_overrides[get_current_user] = mock_get_correct_user + try: + response = test_client.get("/storage/file/user1/element123.txt") + assert response.status_code == 200 + assert response.text == "user file content" + finally: + app.dependency_overrides.clear() + + # Wrong user should be denied + def mock_get_wrong_user(): + return wrong_user + + app.dependency_overrides[get_current_user] = mock_get_wrong_user + try: + response = test_client.get("/storage/file/user1/element123.txt") + assert response.status_code == 403 + assert ( + "Access denied: file belongs to different user" + in response.json()["detail"] + ) + finally: + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_security_and_edge_cases(self, temp_storage_dir): + """Test security features and edge cases.""" + client = LocalStorageClient(storage_path=str(temp_storage_dir)) + + # Test very long filename (legitimate use case) + long_name = "a" * 100 + ".txt" # Reduced from 200 to be more realistic + result = await client.upload_file(long_name, "content") + assert result["object_key"] == long_name + + # Test file with no extension + await client.upload_file("no_extension", "content") + download_result = await client.download_file("no_extension") + assert download_result is not None + _, mime_type = download_result + assert mime_type == "application/octet-stream" # Default MIME type + + # Test empty file + await client.upload_file("empty.txt", "") + download_result = await client.download_file("empty.txt") + assert download_result is not None + content, _ = download_result + assert content == b"" + + # Test that path traversal is blocked in all methods + traversal_path = "../../../etc/passwd" + + # Upload should fail + upload_result = await client.upload_file(traversal_path, "malicious") + assert upload_result == {} + + # Download should fail + download_result = await client.download_file(traversal_path) + assert download_result is None + + # Delete should fail + delete_result = await 
client.delete_file(traversal_path) + assert delete_result is False + + # get_read_url should return fallback + url_result = await client.get_read_url(traversal_path) + assert url_result == traversal_path
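
Usage note (outside the patch): a minimal sketch of exercising the local storage provider added above. The storage path and object key are illustrative; in a real deployment the client is constructed by get_data_layer() when APP_LOCAL_STORAGE_PATH is set, rather than instantiated by hand.

# Sketch only: assumes this patch is applied; the path and key below are examples.
import asyncio

from chainlit.data.storage_clients.local import LocalStorageClient

# Equivalent configuration through the data layer:
#   export APP_LOCAL_STORAGE_PATH=/var/lib/chainlit/storage


async def main() -> None:
    client = LocalStorageClient(storage_path="/tmp/chainlit-storage")

    # upload_file returns the object key plus a backend-proxied URL.
    result = await client.upload_file(
        "threads/thread123/files/element456.txt", "hello", mime="text/plain"
    )
    print(result["url"])  # /storage/file/threads/thread123/files/element456.txt

    # get_read_url points at the same /storage/file/ route served by server.py.
    print(await client.get_read_url("threads/thread123/files/element456.txt"))

    # download_file returns (content, mime_type); delete_file returns a bool.
    downloaded = await client.download_file("threads/thread123/files/element456.txt")
    assert downloaded == (b"hello", "text/plain")
    assert await client.delete_file("threads/thread123/files/element456.txt")


asyncio.run(main())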
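
The new download_file() hook on BaseStorageClient is what lets GET /storage/file/{object_key} proxy bytes through the backend. Below is a hedged sketch of how another storage client could opt in; the in-memory class is hypothetical, only the hook's contract (return (content, mime_type) or None) comes from this patch, and the upload_file signature simply mirrors LocalStorageClient above.

# Hypothetical in-memory client; only the download_file contract is from the patch.
from typing import Any, Dict, Union

from chainlit.data.storage_clients.base import BaseStorageClient


class InMemoryStorageClient(BaseStorageClient):
    """Toy example that keeps blobs in a dict instead of a real backend."""

    def __init__(self) -> None:
        self._blobs: Dict[str, tuple[bytes, str]] = {}

    async def upload_file(
        self,
        object_key: str,
        data: Union[bytes, str],
        mime: str = "application/octet-stream",
        overwrite: bool = True,
        content_disposition: str | None = None,
    ) -> Dict[str, Any]:
        if object_key in self._blobs and not overwrite:
            return {}
        payload = data.encode("utf-8") if isinstance(data, str) else data
        self._blobs[object_key] = (payload, mime)
        return {"object_key": object_key, "url": f"/storage/file/{object_key}"}

    async def delete_file(self, object_key: str) -> bool:
        return self._blobs.pop(object_key, None) is not None

    async def get_read_url(self, object_key: str) -> str:
        # Point the UI at the backend route added in server.py.
        return f"/storage/file/{object_key}"

    async def download_file(self, object_key: str) -> tuple[bytes, str] | None:
        # Returning (content, mime_type) lets the /storage/file route serve the
        # bytes; returning None keeps the default "not supported" 404 behaviour.
        return self._blobs.get(object_key)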