diff --git a/surfsense_backend/tests/test_blocknote_converter.py b/surfsense_backend/tests/test_blocknote_converter.py
new file mode 100644
index 000000000..60a770e47
--- /dev/null
+++ b/surfsense_backend/tests/test_blocknote_converter.py
@@ -0,0 +1,380 @@
+"""
+Tests for the blocknote_converter utility module.
+
+These tests validate:
+1. Empty/invalid input is handled gracefully (returns None, not crash)
+2. API failures don't crash the application
+3. Response structure is correctly parsed
+4. Network errors are properly handled
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+import httpx
+
+# Skip these tests if app dependencies aren't installed
+pytest.importorskip("yaml")
+
+from app.utils.blocknote_converter import (
+ convert_markdown_to_blocknote,
+ convert_blocknote_to_markdown,
+)
+
+
+class TestMarkdownToBlocknoteInputValidation:
+    """
+    Input-handling checks for the markdown -> BlockNote converter.
+    """
+
+    @pytest.mark.asyncio
+    async def test_empty_string_returns_none(self):
+        """
+        An empty markdown string yields None rather than raising.
+        Documents with no content yet hit this path.
+        """
+        converted = await convert_markdown_to_blocknote("")
+        assert converted is None
+
+    @pytest.mark.asyncio
+    async def test_whitespace_only_returns_none(self):
+        """
+        Input made only of whitespace is treated like empty input.
+        Spaces, tabs and newlines on their own are not content.
+        """
+        blank_inputs = (" ", "\t\t", "\n\n", " \n \t ")
+
+        for blank in blank_inputs:
+            converted = await convert_markdown_to_blocknote(blank)
+            assert converted is None, f"Expected None for whitespace: {blank!r}"
+
+    @pytest.mark.asyncio
+    async def test_very_short_content_returns_fallback(self):
+        """
+        A single character still produces a fallback document.
+        The first block of that fallback is expected to be a paragraph.
+        """
+        converted = await convert_markdown_to_blocknote("x")
+
+        assert converted is not None
+        assert isinstance(converted, list)
+        assert len(converted) > 0
+        # The fallback's leading block must be a paragraph.
+        assert converted[0]["type"] == "paragraph"
+
+
+class TestMarkdownToBlocknoteNetworkResilience:
+    """
+    Network failures during markdown -> BlockNote conversion must yield None.
+    The mocked __aexit__ returns False so the mock itself never hides an error.
+    """
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_timeout_returns_none_not_exception(
+        self, mock_config, mock_client_class
+    ):
+        """
+        Network timeout must return None, not raise exception.
+        Timeouts are common and shouldn't crash the application.
+        """
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        # Long enough content to trigger API call
+        result = await convert_markdown_to_blocknote(
+            "# Heading\n\nThis is a paragraph with enough content."
+        )
+
+        assert result is None  # Not an exception
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_server_error_returns_none_not_exception(
+        self, mock_config, mock_client_class
+    ):
+        """
+        HTTP 5xx errors must return None, not raise exception.
+        Server errors shouldn't crash the caller.
+        """
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        mock_response = MagicMock()
+        mock_response.status_code = 500
+        mock_response.text = "Internal Server Error"
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(
+            side_effect=httpx.HTTPStatusError(
+                "Server error",
+                request=MagicMock(),
+                response=mock_response,
+            )
+        )
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        result = await convert_markdown_to_blocknote(
+            "# Heading\n\nThis is a paragraph with enough content."
+        )
+
+        assert result is None
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_connection_error_returns_none(self, mock_config, mock_client_class):
+        """
+        Connection errors (server unreachable) must return None.
+        """
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        result = await convert_markdown_to_blocknote(
+            "# Heading\n\nThis is a paragraph with enough content."
+        )
+
+        assert result is None
+
+
+class TestMarkdownToBlocknoteSuccessfulConversion:
+    """
+    Successful-response scenarios; mocked __aexit__ returns False so errors are never hidden.
+    """
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_successful_conversion_returns_document(
+        self, mock_config, mock_client_class
+    ):
+        """
+        Successful API response should return the BlockNote document.
+        """
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        expected_document = [{"type": "paragraph", "content": [{"text": "Test"}]}]
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"blocknote_document": expected_document}
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(return_value=mock_response)
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        result = await convert_markdown_to_blocknote(
+            "# This is a heading\n\nThis is a paragraph with enough content."
+        )
+
+        assert result == expected_document
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_empty_api_response_returns_none(
+        self, mock_config, mock_client_class
+    ):
+        """
+        If API returns null/empty document, function should return None.
+        """
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"blocknote_document": None}
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(return_value=mock_response)
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        result = await convert_markdown_to_blocknote(
+            "# Heading\n\nSome content that is long enough."
+        )
+
+        assert result is None
+
+
+class TestBlocknoteToMarkdownInputValidation:
+    """
+    Input-handling checks for the BlockNote -> markdown converter.
+    """
+
+    @pytest.mark.asyncio
+    async def test_none_document_returns_none(self):
+        """A None document yields None instead of raising."""
+        markdown = await convert_blocknote_to_markdown(None)
+        assert markdown is None
+
+    @pytest.mark.asyncio
+    async def test_empty_dict_returns_none(self):
+        """An empty dict counts as a document with no content."""
+        markdown = await convert_blocknote_to_markdown({})
+        assert markdown is None
+
+    @pytest.mark.asyncio
+    async def test_empty_list_returns_none(self):
+        """An empty list counts as a document with no content."""
+        markdown = await convert_blocknote_to_markdown([])
+        assert markdown is None
+
+
+class TestBlocknoteToMarkdownNetworkResilience:
+    """
+    Network failures must yield None; mocked __aexit__ returns False so it never hides errors.
+    """
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_timeout_returns_none(self, mock_config, mock_client_class):
+        """Timeout must return None, not exception."""
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        blocknote_doc = [{"type": "paragraph", "content": []}]
+        result = await convert_blocknote_to_markdown(blocknote_doc)
+
+        assert result is None
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_server_error_returns_none(self, mock_config, mock_client_class):
+        """HTTP errors must return None, not exception."""
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        mock_response = MagicMock()
+        mock_response.status_code = 500
+        mock_response.text = "Internal Server Error"
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(
+            side_effect=httpx.HTTPStatusError(
+                "Server error",
+                request=MagicMock(),
+                response=mock_response,
+            )
+        )
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        blocknote_doc = [{"type": "paragraph", "content": []}]
+        result = await convert_blocknote_to_markdown(blocknote_doc)
+
+        assert result is None
+
+
+class TestBlocknoteToMarkdownSuccessfulConversion:
+    """
+    Successful-response scenarios; mocked __aexit__ returns False so errors are never hidden.
+    """
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_successful_conversion_returns_markdown(
+        self, mock_config, mock_client_class
+    ):
+        """Successful conversion should return markdown string."""
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        expected_markdown = "# Converted Heading\n\nParagraph text."
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"markdown": expected_markdown}
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(return_value=mock_response)
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        blocknote_doc = [
+            {"type": "heading", "content": [{"type": "text", "text": "Test"}]}
+        ]
+        result = await convert_blocknote_to_markdown(blocknote_doc)
+
+        assert result == expected_markdown
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_null_markdown_response_returns_none(
+        self, mock_config, mock_client_class
+    ):
+        """If API returns null markdown, function should return None."""
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"markdown": None}
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(return_value=mock_response)
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        blocknote_doc = [{"type": "paragraph", "content": []}]
+        result = await convert_blocknote_to_markdown(blocknote_doc)
+
+        assert result is None
+
+    @pytest.mark.asyncio
+    @patch("app.utils.blocknote_converter.httpx.AsyncClient")
+    @patch("app.utils.blocknote_converter.config")
+    async def test_list_document_is_handled(self, mock_config, mock_client_class):
+        """
+        List documents (multiple blocks) should be handled correctly.
+        """
+        mock_config.NEXT_FRONTEND_URL = "http://localhost:3000"
+
+        expected_markdown = "- Item 1\n- Item 2"
+
+        mock_response = MagicMock()
+        mock_response.json.return_value = {"markdown": expected_markdown}
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.post = AsyncMock(return_value=mock_response)
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)  # truthy would suppress errors
+        mock_client_class.return_value = mock_client
+
+        blocknote_doc = [
+            {
+                "type": "bulletListItem",
+                "content": [{"type": "text", "text": "Item 1"}],
+            },
+            {
+                "type": "bulletListItem",
+                "content": [{"type": "text", "text": "Item 2"}],
+            },
+        ]
+        result = await convert_blocknote_to_markdown(blocknote_doc)
+
+        assert result == expected_markdown
diff --git a/surfsense_backend/tests/test_config.py b/surfsense_backend/tests/test_config.py
new file mode 100644
index 000000000..27b0abc55
--- /dev/null
+++ b/surfsense_backend/tests/test_config.py
@@ -0,0 +1,86 @@
+"""
+Tests for config module.
+Tests application configuration and environment variable handling.
+"""
+
+
+class TestConfigEnvironmentVariables:
+    """Tests for config environment variable handling."""
+
+    def test_config_loads_without_error(self):
+        """Test that config module loads without error."""
+        from app.config import config
+
+        # Config should be an object
+        assert config is not None
+
+    def test_config_has_expected_attributes(self):
+        """Test config defines DATABASE_URL and SECRET_KEY (their values may be unset)."""
+        from app.config import config
+
+        # No `or True` escape hatch: it made both assertions pass unconditionally.
+        assert hasattr(config, 'DATABASE_URL')
+        assert hasattr(config, 'SECRET_KEY')
+
+
+class TestGlobalLLMConfigs:
+    """Shape checks for the GLOBAL_LLM_CONFIGS setting."""
+
+    def test_global_llm_configs_is_list(self):
+        """GLOBAL_LLM_CONFIGS must be a list instance."""
+        from app.config import config
+
+        assert isinstance(config.GLOBAL_LLM_CONFIGS, list)
+
+    def test_global_llm_configs_have_required_fields(self):
+        """Every global entry carries id, name, provider and model_name."""
+        from app.config import config
+
+        wanted_keys = {"id", "name", "provider", "model_name"}
+
+        for entry in config.GLOBAL_LLM_CONFIGS:
+            for key in wanted_keys:
+                assert key in entry, f"Missing field {key} in global config"
+
+    def test_global_llm_configs_have_negative_ids(self):
+        """Every global entry uses an id below zero."""
+        from app.config import config
+
+        for entry in config.GLOBAL_LLM_CONFIGS:
+            assert entry["id"] < 0, f"Global config {entry['name']} should have negative ID"
+
+
+class TestEmbeddingModelInstance:
+    """Checks on the configured embedding model object."""
+
+    def test_embedding_model_instance_exists(self):
+        """The config object defines embedding_model_instance."""
+        from app.config import config
+
+        # Only presence of the attribute is checked here.
+        assert hasattr(config, "embedding_model_instance")
+
+    def test_embedding_model_has_embed_method(self):
+        """When an embedding model is set, it exposes an embed attribute."""
+        from app.config import config
+
+        model = config.embedding_model_instance
+        assert model is None or hasattr(model, "embed")
+
+
+class TestAuthConfiguration:
+    """Type checks on authentication-related settings."""
+
+    def test_auth_type_is_string(self):
+        """AUTH_TYPE, when defined, is a str."""
+        from app.config import config
+
+        auth_type = getattr(config, "AUTH_TYPE", "")
+        assert isinstance(auth_type, str)
+
+    def test_registration_enabled_is_boolean(self):
+        """REGISTRATION_ENABLED, when defined, is a bool."""
+        from app.config import config
+
+        registration = getattr(config, "REGISTRATION_ENABLED", False)
+        assert isinstance(registration, bool)
diff --git a/surfsense_backend/tests/test_db_models.py b/surfsense_backend/tests/test_db_models.py
new file mode 100644
index 000000000..2a3a2b8db
--- /dev/null
+++ b/surfsense_backend/tests/test_db_models.py
@@ -0,0 +1,325 @@
+"""
+Tests for database models and functions.
+Tests SQLAlchemy models, enums, and database utility functions.
+"""
+
+from app.db import (
+ DocumentType,
+ LiteLLMProvider,
+ SearchSourceConnectorType,
+ Permission,
+ SearchSpace,
+ Document,
+ Chunk,
+ Chat,
+ Podcast,
+ LLMConfig,
+ SearchSourceConnector,
+ SearchSpaceRole,
+ SearchSpaceMembership,
+ SearchSpaceInvite,
+ User,
+ LogLevel,
+ LogStatus,
+ ChatType,
+)
+
+
+class TestDocumentType:
+    """Checks on the DocumentType enum members and values."""
+
+    def test_all_document_types_are_strings(self):
+        """Every DocumentType member carries a str value."""
+        for member in DocumentType:
+            assert isinstance(member.value, str)
+
+    def test_extension_type(self):
+        """EXTENSION's value equals its own name."""
+        assert DocumentType.EXTENSION.value == "EXTENSION"
+
+    def test_file_type(self):
+        """FILE's value equals its own name."""
+        assert DocumentType.FILE.value == "FILE"
+
+    def test_youtube_video_type(self):
+        """YOUTUBE_VIDEO's value equals its own name."""
+        assert DocumentType.YOUTUBE_VIDEO.value == "YOUTUBE_VIDEO"
+
+    def test_crawled_url_type(self):
+        """CRAWLED_URL's value equals its own name."""
+        assert DocumentType.CRAWLED_URL.value == "CRAWLED_URL"
+
+    def test_connector_types_exist(self):
+        """Each listed connector name is an attribute of DocumentType."""
+        expected_members = (
+            "SLACK_CONNECTOR",
+            "NOTION_CONNECTOR",
+            "GITHUB_CONNECTOR",
+            "JIRA_CONNECTOR",
+            "CONFLUENCE_CONNECTOR",
+            "LINEAR_CONNECTOR",
+            "DISCORD_CONNECTOR",
+        )
+
+        for member_name in expected_members:
+            assert hasattr(DocumentType, member_name)
+
+
+class TestLiteLLMProvider:
+    """Checks on the LiteLLMProvider enum members and values."""
+
+    def test_openai_provider(self):
+        """OPENAI's value equals its own name."""
+        assert LiteLLMProvider.OPENAI.value == "OPENAI"
+
+    def test_anthropic_provider(self):
+        """ANTHROPIC's value equals its own name."""
+        assert LiteLLMProvider.ANTHROPIC.value == "ANTHROPIC"
+
+    def test_google_provider(self):
+        """GOOGLE's value equals its own name."""
+        assert LiteLLMProvider.GOOGLE.value == "GOOGLE"
+
+    def test_ollama_provider(self):
+        """OLLAMA's value equals its own name."""
+        assert LiteLLMProvider.OLLAMA.value == "OLLAMA"
+
+    def test_all_providers_are_strings(self):
+        """Every LiteLLMProvider member carries a str value."""
+        for member in LiteLLMProvider:
+            assert isinstance(member.value, str)
+
+
+class TestSearchSourceConnectorType:
+    """Checks on the SearchSourceConnectorType enum members and values."""
+
+    def test_tavily_api(self):
+        """TAVILY_API's value equals its own name."""
+        assert SearchSourceConnectorType.TAVILY_API.value == "TAVILY_API"
+
+    def test_searxng_api(self):
+        """SEARXNG_API's value equals its own name."""
+        assert SearchSourceConnectorType.SEARXNG_API.value == "SEARXNG_API"
+
+    def test_slack_connector(self):
+        """SLACK_CONNECTOR's value equals its own name."""
+        assert SearchSourceConnectorType.SLACK_CONNECTOR.value == "SLACK_CONNECTOR"
+
+    def test_notion_connector(self):
+        """NOTION_CONNECTOR's value equals its own name."""
+        assert SearchSourceConnectorType.NOTION_CONNECTOR.value == "NOTION_CONNECTOR"
+
+    def test_all_connector_types_are_strings(self):
+        """Every SearchSourceConnectorType member carries a str value."""
+        for member in SearchSourceConnectorType:
+            assert isinstance(member.value, str)
+
+
+class TestPermission:
+    """Checks on the Permission enum."""
+
+    def test_full_access_permission(self):
+        """FULL_ACCESS carries the wildcard value."""
+        assert Permission.FULL_ACCESS.value == "*"
+
+    def test_document_permissions(self):
+        """The four DOCUMENTS_* permissions are defined."""
+        expected_names = (
+            "DOCUMENTS_CREATE",
+            "DOCUMENTS_READ",
+            "DOCUMENTS_UPDATE",
+            "DOCUMENTS_DELETE",
+        )
+
+        for name in expected_names:
+            assert hasattr(Permission, name)
+
+    def test_chat_permissions(self):
+        """The four CHATS_* permissions are defined."""
+        expected_names = (
+            "CHATS_CREATE",
+            "CHATS_READ",
+            "CHATS_UPDATE",
+            "CHATS_DELETE",
+        )
+
+        for name in expected_names:
+            assert hasattr(Permission, name)
+
+    def test_llm_config_permissions(self):
+        """The four LLM_CONFIGS_* permissions are defined."""
+        expected_names = (
+            "LLM_CONFIGS_CREATE",
+            "LLM_CONFIGS_READ",
+            "LLM_CONFIGS_UPDATE",
+            "LLM_CONFIGS_DELETE",
+        )
+
+        for name in expected_names:
+            assert hasattr(Permission, name)
+
+    def test_settings_permissions(self):
+        """The three SETTINGS_* permissions are defined."""
+        expected_names = (
+            "SETTINGS_VIEW",
+            "SETTINGS_UPDATE",
+            "SETTINGS_DELETE",
+        )
+
+        for name in expected_names:
+            assert hasattr(Permission, name)
+
+
+class TestSearchSpaceModel:
+    """Attribute checks for the SearchSpace model."""
+
+    def test_search_space_has_required_fields(self):
+        """SearchSpace exposes id, name, user_id and created_at."""
+        # Attribute names the model class is expected to expose.
+        expected = ("id", "name", "user_id", "created_at")
+        for column in expected:
+            # Checked in the same order as listed above.
+            assert hasattr(SearchSpace, column)
+
+
+class TestDocumentModel:
+    """Attribute checks for the Document model."""
+
+    def test_document_has_required_fields(self):
+        """Document exposes its identifying and content attributes."""
+        # Attribute names the model class is expected to expose.
+        expected = ("id", "title", "document_type", "content", "search_space_id")
+        for column in expected:
+            # Checked in the same order as listed above.
+            assert hasattr(Document, column)
+
+    def test_document_has_chunks_relationship(self):
+        """Document exposes a chunks attribute."""
+        assert hasattr(Document, "chunks")
+
+
+class TestChunkModel:
+    """Attribute checks for the Chunk model."""
+
+    def test_chunk_has_required_fields(self):
+        """Chunk exposes id, content and document_id."""
+        expected = ("id", "content", "document_id")
+        for column in expected:
+            assert hasattr(Chunk, column)
+
+    def test_chunk_has_embedding_field(self):
+        """Chunk exposes an embedding attribute."""
+        assert hasattr(Chunk, "embedding")
+
+
+class TestChatModel:
+    """Attribute checks for the Chat model."""
+
+    def test_chat_has_required_fields(self):
+        """Chat exposes id, title and search_space_id."""
+        expected = ("id", "title", "search_space_id")
+        for column in expected:
+            assert hasattr(Chat, column)
+
+
+class TestChatType:
+    """Checks on the ChatType enum."""
+
+    def test_chat_type_values(self):
+        """ChatType defines a QNA member."""
+        assert hasattr(ChatType, "QNA")
+
+
+class TestLogLevel:
+    """Checks on the LogLevel enum."""
+
+    def test_log_level_values(self):
+        """LogLevel defines INFO, WARNING and ERROR."""
+        expected = ("INFO", "WARNING", "ERROR")
+        for level in expected:
+            assert hasattr(LogLevel, level)
+
+
+class TestLogStatus:
+    """Checks on the LogStatus enum."""
+
+    def test_log_status_values(self):
+        """LogStatus defines IN_PROGRESS, SUCCESS and FAILED."""
+        for state in ("IN_PROGRESS", "SUCCESS", "FAILED"):
+            assert hasattr(LogStatus, state)
+        # IN_PROGRESS is additionally pinned to its literal value.
+        assert LogStatus.IN_PROGRESS.value == "IN_PROGRESS"
+
+
+class TestLLMConfigModel:
+    """Attribute checks for the LLMConfig model."""
+
+    def test_llm_config_has_required_fields(self):
+        """LLMConfig exposes its identity, provider and credential attributes."""
+        expected = (
+            "id", "name", "provider",
+            "model_name", "api_key", "search_space_id",
+        )
+        for column in expected:
+            assert hasattr(LLMConfig, column)
+
+
+class TestSearchSourceConnectorModel:
+    """Attribute checks for the SearchSourceConnector model."""
+
+    def test_connector_has_required_fields(self):
+        """SearchSourceConnector exposes id, connector_type, config and search_space_id."""
+        # Attribute names the model class is expected to expose.
+        expected = ("id", "connector_type", "config", "search_space_id")
+        for column in expected:
+            assert hasattr(SearchSourceConnector, column)
+
+
+class TestRBACModels:
+    """Attribute checks for the role, membership and invite models."""
+
+    def test_search_space_role_has_required_fields(self):
+        """SearchSpaceRole exposes id, name, permissions and search_space_id."""
+        # Attribute names the model class is expected to expose.
+        expected = ("id", "name", "permissions", "search_space_id")
+        for column in expected:
+            assert hasattr(SearchSpaceRole, column)
+
+    def test_search_space_membership_has_required_fields(self):
+        """SearchSpaceMembership exposes its user, space, role and owner attributes."""
+        # Attribute names the model class is expected to expose.
+        expected = ("id", "user_id", "search_space_id", "role_id", "is_owner")
+        for column in expected:
+            # Checked in the same order as listed above.
+            assert hasattr(SearchSpaceMembership, column)
+
+    def test_search_space_invite_has_required_fields(self):
+        """SearchSpaceInvite exposes id, invite_code, search_space_id and role_id."""
+        # Attribute names the model class is expected to expose.
+        expected = ("id", "invite_code", "search_space_id", "role_id")
+        for column in expected:
+            assert hasattr(SearchSpaceInvite, column)
+
+
+class TestUserModel:
+    """Attribute checks for the User model."""
+
+    def test_user_has_required_fields(self):
+        """User exposes id and email."""
+        for column in ("id", "email"):
+            assert hasattr(User, column)
+
+    def test_user_has_page_limit_fields(self):
+        """User exposes pages_used and pages_limit."""
+        for column in ("pages_used", "pages_limit"):
+            assert hasattr(User, column)
+
+
+class TestPodcastModel:
+    """Attribute checks for the Podcast model."""
+
+    def test_podcast_has_required_fields(self):
+        """Podcast exposes id, title and search_space_id."""
+        expected = ("id", "title", "search_space_id")
+        for column in expected:
+            assert hasattr(Podcast, column)
diff --git a/surfsense_backend/tests/test_document_converters.py b/surfsense_backend/tests/test_document_converters.py
new file mode 100644
index 000000000..2c06a3107
--- /dev/null
+++ b/surfsense_backend/tests/test_document_converters.py
@@ -0,0 +1,513 @@
+"""
+Tests for document_converters utility module.
+
+This module tests the document conversion functions including
+content hash generation, markdown conversion, and chunking utilities.
+"""
+
+import hashlib
+from unittest.mock import MagicMock
+
+import pytest
+
+from app.db import DocumentType
+from app.utils.document_converters import (
+ convert_chunks_to_langchain_documents,
+ convert_document_to_markdown,
+ convert_element_to_markdown,
+ generate_content_hash,
+ generate_unique_identifier_hash,
+)
+
+
+class TestGenerateContentHash:
+    """Behavioural checks for generate_content_hash."""
+
+    def test_generates_sha256_hash(self):
+        """The digest is 64 lowercase hexadecimal characters."""
+        hex_alphabet = set("0123456789abcdef")
+        digest = generate_content_hash("Test content", 1)
+
+        # A SHA-256 hex digest is always 64 characters long.
+        assert len(digest) == 64
+        # Every character must come from the lowercase hex alphabet.
+        assert set(digest) <= hex_alphabet
+
+    def test_combines_content_and_search_space_id(self):
+        """The digest matches SHA-256 over the id, a colon, then the content."""
+        text, space_id = "Test content", 1
+        actual = generate_content_hash(text, space_id)
+
+        # Rebuild the digest independently from the same two inputs.
+        payload = f"{space_id}:{text}".encode("utf-8")
+        reference = hashlib.sha256(payload).hexdigest()
+
+        # Both computations must agree exactly.
+        assert actual == reference
+
+    def test_different_content_produces_different_hash(self):
+        """Changing only the content changes the digest."""
+        first = generate_content_hash("Content 1", 1)
+        second = generate_content_hash("Content 2", 1)
+        assert first != second
+
+    def test_different_search_space_produces_different_hash(self):
+        """Changing only the search space id changes the digest."""
+        first = generate_content_hash("Same content", 1)
+        second = generate_content_hash("Same content", 2)
+        assert first != second
+
+    def test_same_input_produces_same_hash(self):
+        """Two calls with identical arguments return the same digest."""
+        # One argument tuple reused for both calls.
+        args = ("Consistent content", 42)
+
+        first = generate_content_hash(*args)
+        second = generate_content_hash(*args)
+        assert first == second
+
+    def test_empty_content(self):
+        """Empty content still yields a 64-character digest."""
+        digest = generate_content_hash("", 1)
+        assert len(digest) == 64  # Still produces valid hash
+
+    def test_unicode_content(self):
+        """Non-ASCII content yields a 64-character digest."""
+        digest = generate_content_hash("こんにちは世界 🌍", 1)
+        assert len(digest) == 64
+
+
+class TestGenerateUniqueIdentifierHash:
+    """Behavioural checks for generate_unique_identifier_hash."""
+
+    def test_generates_sha256_hash(self):
+        """The digest is 64 lowercase hexadecimal characters."""
+        digest = generate_unique_identifier_hash(
+            DocumentType.SLACK_CONNECTOR,
+            "message123",
+            1,
+        )
+        assert len(digest) == 64
+        assert set(digest) <= set("0123456789abcdef")
+
+    def test_combines_all_parameters(self):
+        """The digest matches SHA-256 over type value, identifier and space id joined by colons."""
+        kind = DocumentType.SLACK_CONNECTOR
+        identifier, space_id = "message123", 42
+
+        # Rebuild the digest independently from the same three inputs.
+        payload = f"{kind.value}:{identifier}:{space_id}".encode("utf-8")
+        reference = hashlib.sha256(payload).hexdigest()
+
+        # Both computations must agree exactly.
+        actual = generate_unique_identifier_hash(kind, identifier, space_id)
+        assert actual == reference
+
+    def test_different_document_types_produce_different_hashes(self):
+        """Changing only the document type changes the digest."""
+        slack = generate_unique_identifier_hash(DocumentType.SLACK_CONNECTOR, "id123", 1)
+        notion = generate_unique_identifier_hash(DocumentType.NOTION_CONNECTOR, "id123", 1)
+        assert slack != notion
+
+    def test_different_identifiers_produce_different_hashes(self):
+        """Changing only the identifier changes the digest."""
+        first = generate_unique_identifier_hash(DocumentType.SLACK_CONNECTOR, "id123", 1)
+        second = generate_unique_identifier_hash(DocumentType.SLACK_CONNECTOR, "id456", 1)
+        assert first != second
+
+    def test_integer_identifier(self):
+        """An int identifier still yields a 64-character digest."""
+        digest = generate_unique_identifier_hash(DocumentType.JIRA_CONNECTOR, 12345, 1)
+        assert len(digest) == 64
+
+    def test_float_identifier(self):
+        """A float identifier (e.g., Slack timestamps) yields a 64-character digest."""
+        digest = generate_unique_identifier_hash(
+            DocumentType.SLACK_CONNECTOR,
+            1234567890.123456,
+            1,
+        )
+        assert len(digest) == 64
+
+    def test_consistency(self):
+        """Two calls with identical arguments return the same digest."""
+        args = (DocumentType.GITHUB_CONNECTOR, "pr-123", 5)
+
+        first = generate_unique_identifier_hash(*args)
+        second = generate_unique_identifier_hash(*args)
+        assert first == second
+
+
+class TestConvertElementToMarkdown:
+ """Tests for convert_element_to_markdown function."""
+
+    @pytest.mark.asyncio
+    async def test_formula_element(self):
+        """A Formula element's output contains a math fence and the formula text."""
+        formula = MagicMock(
+            metadata={"category": "Formula"},
+            page_content="E = mc^2",
+        )
+        markdown = await convert_element_to_markdown(formula)
+        assert "```math" in markdown
+        assert "E = mc^2" in markdown
+
+    @pytest.mark.asyncio
+    async def test_figure_caption_element(self):
+        """A FigureCaption element's output contains the italic Figure marker."""
+        caption = MagicMock(
+            metadata={"category": "FigureCaption"},
+            page_content="Figure 1: Test image",
+        )
+        markdown = await convert_element_to_markdown(caption)
+        assert "*Figure:" in markdown
+
+    @pytest.mark.asyncio
+    async def test_narrative_text_element(self):
+        """A NarrativeText element keeps its text and ends with a blank line."""
+        paragraph = MagicMock(
+            metadata={"category": "NarrativeText"},
+            page_content="This is a paragraph of text.",
+        )
+        markdown = await convert_element_to_markdown(paragraph)
+        assert "This is a paragraph of text." in markdown
+        assert markdown.endswith("\n\n")
+
+    @pytest.mark.asyncio
+    async def test_list_item_element(self):
+        """A ListItem element is rendered as a dash-prefixed bullet."""
+        item = MagicMock(
+            metadata={"category": "ListItem"},
+            page_content="Item one",
+        )
+        markdown = await convert_element_to_markdown(item)
+        assert markdown.startswith("- ")
+        assert "Item one" in markdown
+
+    @pytest.mark.asyncio
+    async def test_title_element(self):
+        """A Title element is rendered as a level-one heading."""
+        title = MagicMock(
+            metadata={"category": "Title"},
+            page_content="Document Title",
+        )
+        markdown = await convert_element_to_markdown(title)
+        assert markdown.startswith("# ")
+        assert "Document Title" in markdown
+
+    @pytest.mark.asyncio
+    async def test_address_element(self):
+        """An Address element is rendered as a block quote."""
+        address = MagicMock(
+            metadata={"category": "Address"},
+            page_content="123 Main St",
+        )
+        markdown = await convert_element_to_markdown(address)
+        assert markdown.startswith("> ")
+
+    @pytest.mark.asyncio
+    async def test_email_address_element(self):
+        """An EmailAddress element is wrapped in inline-code backticks."""
+        email = MagicMock(
+            metadata={"category": "EmailAddress"},
+            page_content="test@example.com",
+        )
+        markdown = await convert_element_to_markdown(email)
+        assert "`test@example.com`" in markdown
+
+    @pytest.mark.asyncio
+    async def test_table_element(self):
+        """Test Table element conversion: output has an html fence and the table markup."""
+        element = MagicMock()
+        element.metadata = {"category": "Table", "text_as_html": "<table><tr><td>Cell</td></tr></table>"}
+        element.page_content = "Table content"
+
+        result = await convert_element_to_markdown(element)
+        assert "```html" in result
+        assert "<table>" in result  # NOTE(review): markup literal reconstructed; confirm upstream
+
+ @pytest.mark.asyncio
+ async def test_header_element(self):
+ """Test Header element conversion."""
+ element = MagicMock()
+ element.metadata = {"category": "Header"}
+ element.page_content = "Section Header"
+
+ result = await convert_element_to_markdown(element)
+ assert result.startswith("## ")
+
+ @pytest.mark.asyncio
+ async def test_code_snippet_element(self):
+ """Test CodeSnippet element conversion."""
+ element = MagicMock()
+ element.metadata = {"category": "CodeSnippet"}
+ element.page_content = "print('hello')"
+
+ result = await convert_element_to_markdown(element)
+ assert "```" in result
+ assert "print('hello')" in result
+
+ @pytest.mark.asyncio
+ async def test_page_number_element(self):
+ """Test PageNumber element conversion."""
+ element = MagicMock()
+ element.metadata = {"category": "PageNumber"}
+ element.page_content = "42"
+
+ result = await convert_element_to_markdown(element)
+ assert "*Page 42*" in result
+
+ @pytest.mark.asyncio
+ async def test_page_break_element(self):
+ """Test PageBreak element conversion."""
+ element = MagicMock()
+ element.metadata = {"category": "PageBreak"}
+ # PageBreak with content returns horizontal rule
+ element.page_content = "page break content"
+
+ result = await convert_element_to_markdown(element)
+ assert "---" in result
+
+ @pytest.mark.asyncio
+ async def test_empty_content(self):
+ """Test element with empty content."""
+ element = MagicMock()
+ element.metadata = {"category": "NarrativeText"}
+ element.page_content = ""
+
+ result = await convert_element_to_markdown(element)
+ assert result == ""
+
+ @pytest.mark.asyncio
+ async def test_uncategorized_element(self):
+ """Test UncategorizedText element conversion."""
+ element = MagicMock()
+ element.metadata = {"category": "UncategorizedText"}
+ element.page_content = "Some uncategorized text"
+
+ result = await convert_element_to_markdown(element)
+ assert "Some uncategorized text" in result
+
+
+class TestConvertDocumentToMarkdown:
+ """Tests for convert_document_to_markdown function."""
+
+ @pytest.mark.asyncio
+ async def test_converts_multiple_elements(self):
+ """Test converting multiple elements."""
+ elements = []
+
+ # Title element
+ title = MagicMock()
+ title.metadata = {"category": "Title"}
+ title.page_content = "Document Title"
+ elements.append(title)
+
+ # Narrative text element
+ para = MagicMock()
+ para.metadata = {"category": "NarrativeText"}
+ para.page_content = "This is a paragraph."
+ elements.append(para)
+
+ result = await convert_document_to_markdown(elements)
+
+ assert "# Document Title" in result
+ assert "This is a paragraph." in result
+
+ @pytest.mark.asyncio
+ async def test_empty_elements(self):
+ """Test with empty elements list."""
+ result = await convert_document_to_markdown([])
+ assert result == ""
+
+    @pytest.mark.asyncio
+    async def test_preserves_order(self):
+        """Test that element order is preserved."""
+        elements = []
+
+        for i in range(3):
+            elem = MagicMock()
+            elem.metadata = {"category": "NarrativeText"}
+            elem.page_content = f"Paragraph {i}"
+            elements.append(elem)
+
+        result = await convert_document_to_markdown(elements)
+
+        # str.find returns -1 when the text is absent, so anchor the chain at 0
+        pos0 = result.find("Paragraph 0")
+        pos1 = result.find("Paragraph 1")
+        pos2 = result.find("Paragraph 2")
+
+        assert 0 <= pos0 < pos1 < pos2
+
+
+class TestConvertChunksToLangchainDocuments:
+ """Tests for convert_chunks_to_langchain_documents function."""
+
+ def test_converts_basic_chunks(self):
+ """Test converting basic chunk structure."""
+ chunks = [
+ {
+ "chunk_id": 1,
+ "content": "This is chunk content",
+ "score": 0.95,
+ "document": {
+ "id": 10,
+ "title": "Test Document",
+ "document_type": "FILE",
+ "metadata": {"url": "https://example.com"},
+ },
+ }
+ ]
+
+ result = convert_chunks_to_langchain_documents(chunks)
+
+ assert len(result) == 1
+ assert "This is chunk content" in result[0].page_content
+ assert result[0].metadata["chunk_id"] == 1
+ assert result[0].metadata["document_id"] == 10
+ assert result[0].metadata["document_title"] == "Test Document"
+
+ def test_includes_source_id_in_content(self):
+ """Test that source_id is included in XML content."""
+ chunks = [
+ {
+ "chunk_id": 1,
+ "content": "Test content",
+ "score": 0.9,
+ "document": {
+ "id": 5,
+ "title": "Doc",
+ "document_type": "FILE",
+ "metadata": {},
+ },
+ }
+ ]
+
+ result = convert_chunks_to_langchain_documents(chunks)
+
+ assert "5" in result[0].page_content
+
+ def test_extracts_source_url(self):
+ """Test source URL extraction from metadata."""
+ chunks = [
+ {
+ "chunk_id": 1,
+ "content": "Content",
+ "score": 0.9,
+ "document": {
+ "id": 1,
+ "title": "Doc",
+ "document_type": "CRAWLED_URL",
+ "metadata": {"url": "https://example.com/page"},
+ },
+ }
+ ]
+
+ result = convert_chunks_to_langchain_documents(chunks)
+
+ assert result[0].metadata["source"] == "https://example.com/page"
+
+ def test_extracts_source_url_alternate_key(self):
+ """Test source URL extraction with sourceURL key."""
+ chunks = [
+ {
+ "chunk_id": 1,
+ "content": "Content",
+ "score": 0.9,
+ "document": {
+ "id": 1,
+ "title": "Doc",
+ "document_type": "CRAWLED_URL",
+ "metadata": {"sourceURL": "https://example.com/alternate"},
+ },
+ }
+ ]
+
+ result = convert_chunks_to_langchain_documents(chunks)
+
+ assert result[0].metadata["source"] == "https://example.com/alternate"
+
+ def test_handles_missing_document(self):
+ """Test handling chunks without document info."""
+ chunks = [
+ {
+ "chunk_id": 1,
+ "content": "Content without document",
+ "score": 0.8,
+ }
+ ]
+
+ result = convert_chunks_to_langchain_documents(chunks)
+
+ assert len(result) == 1
+ assert "Content without document" in result[0].page_content
+
+ def test_prefixes_document_metadata(self):
+ """Test document metadata is prefixed."""
+ chunks = [
+ {
+ "chunk_id": 1,
+ "content": "Content",
+ "score": 0.9,
+ "document": {
+ "id": 1,
+ "title": "Doc",
+ "document_type": "FILE",
+ "metadata": {"custom_field": "custom_value"},
+ },
+ }
+ ]
+
+ result = convert_chunks_to_langchain_documents(chunks)
+
+ assert result[0].metadata["doc_meta_custom_field"] == "custom_value"
+
+ def test_handles_rank_field(self):
+ """Test handling of rank field when present."""
+ chunks = [
+ {
+ "chunk_id": 1,
+ "content": "Content",
+ "score": 0.9,
+ "rank": 1,
+ "document": {
+ "id": 1,
+ "title": "Doc",
+ "document_type": "FILE",
+ "metadata": {},
+ },
+ }
+ ]
+
+ result = convert_chunks_to_langchain_documents(chunks)
+
+ assert result[0].metadata["rank"] == 1
+
+ def test_empty_chunks_list(self):
+ """Test with empty chunks list."""
+ result = convert_chunks_to_langchain_documents([])
+ assert result == []
+
+ def test_multiple_chunks(self):
+ """Test converting multiple chunks."""
+ chunks = [
+ {
+ "chunk_id": i,
+ "content": f"Content {i}",
+ "score": 0.9 - (i * 0.1),
+ "document": {
+ "id": i,
+ "title": f"Doc {i}",
+ "document_type": "FILE",
+ "metadata": {},
+ },
+ }
+ for i in range(3)
+ ]
+
+ result = convert_chunks_to_langchain_documents(chunks)
+
+ assert len(result) == 3
+ for i, doc in enumerate(result):
+ assert f"Content {i}" in doc.page_content
diff --git a/surfsense_backend/tests/test_permissions.py b/surfsense_backend/tests/test_permissions.py
new file mode 100644
index 000000000..3dbef13b6
--- /dev/null
+++ b/surfsense_backend/tests/test_permissions.py
@@ -0,0 +1,270 @@
+"""
+Tests for permission functions in db.py.
+
+This module tests the permission checking functions used in RBAC.
+"""
+
+from app.db import (
+ DEFAULT_ROLE_PERMISSIONS,
+ Permission,
+ get_default_roles_config,
+ has_all_permissions,
+ has_any_permission,
+ has_permission,
+)
+
+
+class TestHasPermission:
+ """Tests for has_permission function."""
+
+ def test_has_permission_with_exact_match(self):
+ """Test has_permission returns True for exact permission match."""
+ permissions = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value]
+ assert has_permission(permissions, Permission.DOCUMENTS_READ.value) is True
+
+ def test_has_permission_with_no_match(self):
+ """Test has_permission returns False when permission not in list."""
+ permissions = [Permission.DOCUMENTS_READ.value]
+ assert has_permission(permissions, Permission.DOCUMENTS_CREATE.value) is False
+
+    def test_has_permission_with_full_access(self):
+        """FULL_ACCESS alone must satisfy checks for unrelated specific permissions."""
+        granted = [Permission.FULL_ACCESS.value]
+        probes = (Permission.DOCUMENTS_CREATE, Permission.SETTINGS_DELETE, Permission.MEMBERS_MANAGE_ROLES)
+        for probe in probes:
+            assert has_permission(granted, probe.value) is True
+
+ def test_has_permission_with_empty_list(self):
+ """Test has_permission returns False for empty permission list."""
+ assert has_permission([], Permission.DOCUMENTS_READ.value) is False
+
+ def test_has_permission_with_none(self):
+ """Test has_permission returns False for None."""
+ assert has_permission(None, Permission.DOCUMENTS_READ.value) is False
+
+
+class TestHasAnyPermission:
+ """Tests for has_any_permission function."""
+
+ def test_has_any_permission_with_one_match(self):
+ """Test has_any_permission returns True when at least one permission matches."""
+ user_permissions = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value]
+ required = [Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_CREATE.value]
+ assert has_any_permission(user_permissions, required) is True
+
+ def test_has_any_permission_with_all_match(self):
+ """Test has_any_permission returns True when all permissions match."""
+ user_permissions = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value]
+ required = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value]
+ assert has_any_permission(user_permissions, required) is True
+
+ def test_has_any_permission_with_no_match(self):
+ """Test has_any_permission returns False when no permissions match."""
+ user_permissions = [Permission.DOCUMENTS_READ.value]
+ required = [Permission.CHATS_CREATE.value, Permission.SETTINGS_UPDATE.value]
+ assert has_any_permission(user_permissions, required) is False
+
+ def test_has_any_permission_with_full_access(self):
+ """Test has_any_permission returns True with FULL_ACCESS."""
+ user_permissions = [Permission.FULL_ACCESS.value]
+ required = [Permission.SETTINGS_DELETE.value]
+ assert has_any_permission(user_permissions, required) is True
+
+ def test_has_any_permission_with_empty_user_permissions(self):
+ """Test has_any_permission returns False with empty user permissions."""
+ assert has_any_permission([], [Permission.DOCUMENTS_READ.value]) is False
+
+ def test_has_any_permission_with_none(self):
+ """Test has_any_permission returns False with None."""
+ assert has_any_permission(None, [Permission.DOCUMENTS_READ.value]) is False
+
+
+class TestHasAllPermissions:
+ """Tests for has_all_permissions function."""
+
+ def test_has_all_permissions_with_all_match(self):
+ """Test has_all_permissions returns True when all permissions match."""
+ user_permissions = [
+ Permission.DOCUMENTS_READ.value,
+ Permission.DOCUMENTS_CREATE.value,
+ Permission.CHATS_READ.value,
+ ]
+ required = [Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_CREATE.value]
+ assert has_all_permissions(user_permissions, required) is True
+
+ def test_has_all_permissions_with_partial_match(self):
+ """Test has_all_permissions returns False when only some permissions match."""
+ user_permissions = [Permission.DOCUMENTS_READ.value]
+ required = [Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_CREATE.value]
+ assert has_all_permissions(user_permissions, required) is False
+
+ def test_has_all_permissions_with_no_match(self):
+ """Test has_all_permissions returns False when no permissions match."""
+ user_permissions = [Permission.CHATS_READ.value]
+ required = [Permission.DOCUMENTS_READ.value, Permission.DOCUMENTS_CREATE.value]
+ assert has_all_permissions(user_permissions, required) is False
+
+ def test_has_all_permissions_with_full_access(self):
+ """Test has_all_permissions returns True with FULL_ACCESS."""
+ user_permissions = [Permission.FULL_ACCESS.value]
+ required = [
+ Permission.DOCUMENTS_READ.value,
+ Permission.DOCUMENTS_CREATE.value,
+ Permission.SETTINGS_DELETE.value,
+ ]
+ assert has_all_permissions(user_permissions, required) is True
+
+ def test_has_all_permissions_with_empty_user_permissions(self):
+ """Test has_all_permissions returns False with empty user permissions."""
+ assert has_all_permissions([], [Permission.DOCUMENTS_READ.value]) is False
+
+ def test_has_all_permissions_with_none(self):
+ """Test has_all_permissions returns False with None."""
+ assert has_all_permissions(None, [Permission.DOCUMENTS_READ.value]) is False
+
+ def test_has_all_permissions_with_empty_required(self):
+ """Test has_all_permissions returns True with empty required list."""
+ user_permissions = [Permission.DOCUMENTS_READ.value]
+ assert has_all_permissions(user_permissions, []) is True
+
+
+class TestPermissionEnum:
+ """Tests for Permission enum values."""
+
+    def test_permission_values_are_strings(self):
+        """Every member of the Permission enum must carry a str value."""
+        non_string = [perm for perm in Permission if not isinstance(perm.value, str)]
+        assert non_string == []
+
+ def test_permission_document_values(self):
+ """Test document permission values."""
+ assert Permission.DOCUMENTS_CREATE.value == "documents:create"
+ assert Permission.DOCUMENTS_READ.value == "documents:read"
+ assert Permission.DOCUMENTS_UPDATE.value == "documents:update"
+ assert Permission.DOCUMENTS_DELETE.value == "documents:delete"
+
+ def test_permission_chat_values(self):
+ """Test chat permission values."""
+ assert Permission.CHATS_CREATE.value == "chats:create"
+ assert Permission.CHATS_READ.value == "chats:read"
+ assert Permission.CHATS_UPDATE.value == "chats:update"
+ assert Permission.CHATS_DELETE.value == "chats:delete"
+
+ def test_permission_llm_config_values(self):
+ """Test LLM config permission values."""
+ assert Permission.LLM_CONFIGS_CREATE.value == "llm_configs:create"
+ assert Permission.LLM_CONFIGS_READ.value == "llm_configs:read"
+ assert Permission.LLM_CONFIGS_UPDATE.value == "llm_configs:update"
+ assert Permission.LLM_CONFIGS_DELETE.value == "llm_configs:delete"
+
+ def test_permission_members_values(self):
+ """Test member permission values."""
+ assert Permission.MEMBERS_INVITE.value == "members:invite"
+ assert Permission.MEMBERS_VIEW.value == "members:view"
+ assert Permission.MEMBERS_REMOVE.value == "members:remove"
+ assert Permission.MEMBERS_MANAGE_ROLES.value == "members:manage_roles"
+
+ def test_permission_full_access_value(self):
+ """Test FULL_ACCESS permission value."""
+ assert Permission.FULL_ACCESS.value == "*"
+
+
+class TestDefaultRolePermissions:
+ """Tests for DEFAULT_ROLE_PERMISSIONS configuration."""
+
+ def test_owner_has_full_access(self):
+ """Test Owner role has full access."""
+ assert Permission.FULL_ACCESS.value in DEFAULT_ROLE_PERMISSIONS["Owner"]
+
+ def test_admin_permissions(self):
+ """Test Admin role has appropriate permissions."""
+ admin_perms = DEFAULT_ROLE_PERMISSIONS["Admin"]
+ # Admin should have document permissions
+ assert Permission.DOCUMENTS_CREATE.value in admin_perms
+ assert Permission.DOCUMENTS_READ.value in admin_perms
+ assert Permission.DOCUMENTS_UPDATE.value in admin_perms
+ assert Permission.DOCUMENTS_DELETE.value in admin_perms
+ # Admin should NOT have settings:delete
+ assert Permission.SETTINGS_DELETE.value not in admin_perms
+
+ def test_editor_permissions(self):
+ """Test Editor role can manage documents and chats but cannot remove members."""
+ editor_perms = DEFAULT_ROLE_PERMISSIONS["Editor"]
+ # Editor should have document CRUD
+ assert Permission.DOCUMENTS_CREATE.value in editor_perms
+ assert Permission.DOCUMENTS_READ.value in editor_perms
+ assert Permission.DOCUMENTS_UPDATE.value in editor_perms
+ assert Permission.DOCUMENTS_DELETE.value in editor_perms
+ # Editor should be able to create and read chats (update/delete are not asserted here)
+ assert Permission.CHATS_CREATE.value in editor_perms
+ assert Permission.CHATS_READ.value in editor_perms
+ # Editor should NOT be able to remove members
+ assert Permission.MEMBERS_REMOVE.value not in editor_perms
+
+ def test_viewer_permissions(self):
+ """Test Viewer role has read-only permissions."""
+ viewer_perms = DEFAULT_ROLE_PERMISSIONS["Viewer"]
+ # Viewer should have read permissions
+ assert Permission.DOCUMENTS_READ.value in viewer_perms
+ assert Permission.CHATS_READ.value in viewer_perms
+ assert Permission.LLM_CONFIGS_READ.value in viewer_perms
+ # Viewer should NOT have create/update/delete permissions
+ assert Permission.DOCUMENTS_CREATE.value not in viewer_perms
+ assert Permission.DOCUMENTS_UPDATE.value not in viewer_perms
+ assert Permission.DOCUMENTS_DELETE.value not in viewer_perms
+ assert Permission.CHATS_CREATE.value not in viewer_perms
+
+
+class TestGetDefaultRolesConfig:
+ """Tests for get_default_roles_config function."""
+
+ def test_returns_list(self):
+ """Test get_default_roles_config returns a list."""
+ config = get_default_roles_config()
+ assert isinstance(config, list)
+
+ def test_contains_four_roles(self):
+ """Test get_default_roles_config returns 4 roles."""
+ config = get_default_roles_config()
+ assert len(config) == 4
+
+    def test_role_names(self):
+        """The default configuration must include the four expected role names."""
+        config = get_default_roles_config()
+        names = [entry["name"] for entry in config]
+
+        # Membership only: additional roles would not fail this test
+        for expected in ("Owner", "Admin", "Editor", "Viewer"):
+            assert expected in names
+
+ def test_all_roles_are_system_roles(self):
+ """Test all default roles are system roles."""
+ config = get_default_roles_config()
+ for role in config:
+ assert role["is_system_role"] is True
+
+ def test_editor_is_default_role(self):
+ """Test Editor is the default role for new members."""
+ config = get_default_roles_config()
+ editor_role = next(role for role in config if role["name"] == "Editor")
+ assert editor_role["is_default"] is True
+
+ def test_owner_is_not_default_role(self):
+ """Test Owner is not the default role."""
+ config = get_default_roles_config()
+ owner_role = next(role for role in config if role["name"] == "Owner")
+ assert owner_role["is_default"] is False
+
+ def test_role_structure(self):
+ """Test each role has required fields."""
+ config = get_default_roles_config()
+ required_fields = ["name", "description", "permissions", "is_default", "is_system_role"]
+ for role in config:
+ for field in required_fields:
+ assert field in role, f"Role {role.get('name')} missing field {field}"
+
+ def test_owner_role_permissions(self):
+ """Test Owner role has full access permission."""
+ config = get_default_roles_config()
+ owner_role = next(role for role in config if role["name"] == "Owner")
+ assert Permission.FULL_ACCESS.value in owner_role["permissions"]
diff --git a/surfsense_backend/tests/test_rbac.py b/surfsense_backend/tests/test_rbac.py
new file mode 100644
index 000000000..1e0c35bdd
--- /dev/null
+++ b/surfsense_backend/tests/test_rbac.py
@@ -0,0 +1,355 @@
+"""
+Tests for the RBAC (Role-Based Access Control) utility functions.
+
+These tests validate the security-critical RBAC behavior:
+1. Users without membership should NEVER access resources
+2. Permission checks must be strict - no false positives
+3. Owners must have full access
+4. Role permissions must be properly enforced
+"""
+
+import uuid
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from fastapi import HTTPException
+
+# Skip these tests if app dependencies aren't installed
+pytest.importorskip("sqlalchemy")
+pytest.importorskip("fastapi_users")
+
+from app.db import Permission, SearchSpaceMembership, SearchSpaceRole
+from app.utils.rbac import (
+ check_permission,
+ check_search_space_access,
+ generate_invite_code,
+ get_default_role,
+ get_owner_role,
+ get_user_permissions,
+ is_search_space_owner,
+)
+
+
+class TestSecurityCriticalAccessControl:
+ """
+ Critical security tests - these MUST pass to prevent unauthorized access.
+ """
+
+ @pytest.mark.asyncio
+ async def test_non_member_cannot_access_search_space(self, mock_session, mock_user):
+ """
+ SECURITY: Non-members must be denied access with 403.
+ This is critical - allowing access would be a security breach.
+ """
+ search_space_id = 1
+
+ # Simulate user not being a member
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = None
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ with pytest.raises(HTTPException) as exc_info:
+ await check_search_space_access(mock_session, mock_user, search_space_id)
+
+ # Must be 403 Forbidden, not 404 or other
+ assert exc_info.value.status_code == 403
+ assert "access" in exc_info.value.detail.lower()
+
+    @pytest.mark.asyncio
+    async def test_member_without_permission_is_denied(self, mock_session, mock_user):
+        """
+        SECURITY: Members without specific permission must be denied.
+        Having membership alone is insufficient for sensitive operations.
+        """
+        search_space_id = 1
+
+        # Member exists but has limited permissions (only read, not create)
+        mock_role = MagicMock(spec=SearchSpaceRole)
+        mock_role.permissions = [Permission.DOCUMENTS_READ.value]
+
+        mock_membership = MagicMock(spec=SearchSpaceMembership)
+        mock_membership.is_owner = False
+        mock_membership.role = mock_role
+
+        mock_result = MagicMock()
+        mock_result.scalars.return_value.first.return_value = mock_membership
+        mock_session.execute = AsyncMock(return_value=mock_result)
+
+        # has_permission is deliberately NOT patched: stubbing it to False would
+        # make this test pass even if the real permission check were broken.
+        with pytest.raises(HTTPException) as exc_info:
+            await check_permission(
+                mock_session,
+                mock_user,
+                search_space_id,
+                Permission.DOCUMENTS_CREATE.value,
+            )
+
+        assert exc_info.value.status_code == 403
+
+ @pytest.mark.asyncio
+ async def test_owner_has_full_access_regardless_of_operation(
+ self, mock_session, mock_user
+ ):
+ """
+ SECURITY: Owners must have full access to all operations.
+ This ensures owners can always manage their search spaces.
+ """
+ search_space_id = 1
+
+ mock_membership = MagicMock(spec=SearchSpaceMembership)
+ mock_membership.is_owner = True
+ mock_membership.role = None # Owners may not have explicit roles
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = mock_membership
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ # Owner should pass permission check with FULL_ACCESS
+ with patch("app.utils.rbac.has_permission", return_value=True) as mock_has_perm:
+ result = await check_permission(
+ mock_session,
+ mock_user,
+ search_space_id,
+ "any:permission",
+ )
+
+ assert result == mock_membership
+ # Verify FULL_ACCESS was checked
+ mock_has_perm.assert_called_once()
+ call_args = mock_has_perm.call_args[0]
+ assert Permission.FULL_ACCESS.value in call_args[0]
+
+
+class TestGetUserPermissions:
+ """Tests for permission retrieval - validates correct permission inheritance."""
+
+ @pytest.mark.asyncio
+ async def test_non_member_has_no_permissions(self, mock_session):
+ """Non-members must have zero permissions."""
+ user_id = uuid.uuid4()
+ search_space_id = 1
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = None
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await get_user_permissions(mock_session, user_id, search_space_id)
+
+ assert result == []
+ assert len(result) == 0
+
+ @pytest.mark.asyncio
+ async def test_owner_gets_full_access_permission(self, mock_session):
+ """Owners must receive FULL_ACCESS permission."""
+ user_id = uuid.uuid4()
+ search_space_id = 1
+
+ mock_membership = MagicMock(spec=SearchSpaceMembership)
+ mock_membership.is_owner = True
+ mock_membership.role = None
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = mock_membership
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await get_user_permissions(mock_session, user_id, search_space_id)
+
+ assert Permission.FULL_ACCESS.value in result
+
+ @pytest.mark.asyncio
+ async def test_member_gets_only_role_permissions(self, mock_session):
+ """Members should get exactly the permissions from their role - no more, no less."""
+ user_id = uuid.uuid4()
+ search_space_id = 1
+
+ expected_permissions = ["documents:read", "chats:read"]
+
+ mock_role = MagicMock(spec=SearchSpaceRole)
+ mock_role.permissions = expected_permissions.copy()
+
+ mock_membership = MagicMock(spec=SearchSpaceMembership)
+ mock_membership.is_owner = False
+ mock_membership.role = mock_role
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = mock_membership
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await get_user_permissions(mock_session, user_id, search_space_id)
+
+ # Must match exactly - no extra permissions sneaking in
+ assert set(result) == set(expected_permissions)
+ assert len(result) == len(expected_permissions)
+
+ @pytest.mark.asyncio
+ async def test_member_without_role_has_no_permissions(self, mock_session):
+ """Members without an assigned role must have empty permissions."""
+ user_id = uuid.uuid4()
+ search_space_id = 1
+
+ mock_membership = MagicMock(spec=SearchSpaceMembership)
+ mock_membership.is_owner = False
+ mock_membership.role = None
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = mock_membership
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await get_user_permissions(mock_session, user_id, search_space_id)
+
+ assert result == []
+
+
+class TestOwnershipChecks:
+ """Tests for ownership verification."""
+
+ @pytest.mark.asyncio
+ async def test_is_owner_returns_true_only_for_actual_owner(self, mock_session):
+ """is_search_space_owner must return True ONLY for actual owners."""
+ user_id = uuid.uuid4()
+ search_space_id = 1
+
+ mock_membership = MagicMock(spec=SearchSpaceMembership)
+ mock_membership.is_owner = True
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = mock_membership
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await is_search_space_owner(mock_session, user_id, search_space_id)
+
+ assert result is True
+
+ @pytest.mark.asyncio
+ async def test_is_owner_returns_false_for_non_owner_member(self, mock_session):
+ """Regular members must NOT be identified as owners."""
+ user_id = uuid.uuid4()
+ search_space_id = 1
+
+ mock_membership = MagicMock(spec=SearchSpaceMembership)
+ mock_membership.is_owner = False
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = mock_membership
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await is_search_space_owner(mock_session, user_id, search_space_id)
+
+ assert result is False
+
+ @pytest.mark.asyncio
+ async def test_is_owner_returns_false_for_non_member(self, mock_session):
+ """Non-members must NOT be identified as owners."""
+ user_id = uuid.uuid4()
+ search_space_id = 1
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = None
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await is_search_space_owner(mock_session, user_id, search_space_id)
+
+ assert result is False
+
+
+class TestInviteCodeSecurity:
+ """Tests for invite code generation - validates security requirements."""
+
+    def test_invite_codes_are_cryptographically_unique(self):
+        """
+        Invite codes must not repeat across a large batch.
+        Generates many codes and verifies that all of them are distinct.
+        Uniqueness is a necessary (not sufficient) sign of good randomness.
+        """
+        num_codes = 1000
+
+        # A set keeps only distinct values, so any duplicate code would
+        # leave it smaller than the number of codes requested
+        codes = {generate_invite_code() for _ in range(num_codes)}
+
+        # All codes must be distinct - a shortfall means a collision occurred
+        assert len(codes) == num_codes
+
+ def test_invite_code_has_sufficient_entropy(self):
+ """
+ Invite codes must have sufficient length for security.
+ 32 characters of URL-safe base64 = ~192 bits of entropy.
+ """
+ code = generate_invite_code()
+
+ # Minimum 32 characters for adequate security
+ assert len(code) >= 32
+
+    def test_invite_code_is_url_safe(self):
+        """Invite codes must be safe for use in URLs without encoding."""
+        import re
+
+        code = generate_invite_code()
+
+        # fullmatch, because "$" would also accept a trailing newline
+        assert re.fullmatch(r"[A-Za-z0-9_-]+", code) is not None
+
+    def test_invite_codes_are_unpredictable(self):
+        """
+        Back-to-back invite codes must not look alike.
+        Compares the leading characters of each neighbouring pair.
+        """
+        prefix_len = 8
+        batch = [generate_invite_code() for _ in range(10)]
+
+        # Pair every code with its successor; their prefixes must never coincide
+        for earlier, later in zip(batch, batch[1:]):
+            assert earlier[:prefix_len] != later[:prefix_len]
+
+
+class TestRoleRetrieval:
+ """Tests for role lookup functions."""
+
+ @pytest.mark.asyncio
+ async def test_get_default_role_returns_correct_role(self, mock_session):
+ """Default role lookup must return the role marked as default."""
+ search_space_id = 1
+
+ mock_role = MagicMock(spec=SearchSpaceRole)
+ mock_role.name = "Viewer"
+ mock_role.is_default = True
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = mock_role
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await get_default_role(mock_session, search_space_id)
+
+ assert result is not None
+ assert result.is_default is True
+
+ @pytest.mark.asyncio
+ async def test_get_default_role_returns_none_when_no_default(self, mock_session):
+ """Must return None if no default role exists - not raise an error."""
+ search_space_id = 1
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = None
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await get_default_role(mock_session, search_space_id)
+
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_get_owner_role_returns_owner_named_role(self, mock_session):
+ """Owner role lookup must return the role named 'Owner'."""
+ search_space_id = 1
+
+ mock_role = MagicMock(spec=SearchSpaceRole)
+ mock_role.name = "Owner"
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = mock_role
+ mock_session.execute = AsyncMock(return_value=mock_result)
+
+ result = await get_owner_role(mock_session, search_space_id)
+
+ assert result is not None
+ assert result.name == "Owner"
diff --git a/surfsense_backend/tests/test_rbac_schemas.py b/surfsense_backend/tests/test_rbac_schemas.py
new file mode 100644
index 000000000..1d4a336c2
--- /dev/null
+++ b/surfsense_backend/tests/test_rbac_schemas.py
@@ -0,0 +1,392 @@
+"""
+Tests for RBAC schemas.
+
+This module tests the Pydantic schemas used for role-based access control.
+"""
+
+from datetime import datetime, timezone
+from uuid import uuid4
+
+import pytest
+from pydantic import ValidationError
+
+from app.schemas.rbac_schemas import (
+ InviteAcceptRequest,
+ InviteAcceptResponse,
+ InviteBase,
+ InviteCreate,
+ InviteInfoResponse,
+ InviteRead,
+ InviteUpdate,
+ MembershipBase,
+ MembershipRead,
+ MembershipReadWithUser,
+ MembershipUpdate,
+ PermissionInfo,
+ PermissionsListResponse,
+ RoleBase,
+ RoleCreate,
+ RoleRead,
+ RoleUpdate,
+ UserSearchSpaceAccess,
+)
+
+
+class TestRoleSchemas:
+ """Tests for RoleBase, RoleCreate, RoleUpdate and RoleRead."""
+
+ def test_role_base_minimal(self):
+ """RoleBase built with only a name defaults description, permissions and is_default."""
+ role = RoleBase(name="TestRole")
+ assert role.name == "TestRole"
+ assert role.description is None
+ assert role.permissions == []
+ assert role.is_default is False
+
+ def test_role_base_full(self):
+ """RoleBase keeps every explicitly supplied field."""
+ role = RoleBase(
+ name="Admin",
+ description="Administrator role",
+ permissions=["documents:read", "documents:write"],
+ is_default=True,
+ )
+ assert role.name == "Admin"
+ assert role.description == "Administrator role"
+ assert len(role.permissions) == 2
+ assert role.is_default is True
+
+ def test_role_base_name_validation(self):
+ """RoleBase rejects an empty or 101-char name and accepts a 100-char one."""
+ # Empty name should fail
+ with pytest.raises(ValidationError):
+ RoleBase(name="")
+
+ # Name at max length should work
+ role = RoleBase(name="x" * 100)
+ assert len(role.name) == 100
+
+ # Name over max length should fail
+ with pytest.raises(ValidationError):
+ RoleBase(name="x" * 101)
+
+ def test_role_base_description_validation(self):
+ """RoleBase accepts a 500-char description and rejects a 501-char one."""
+ # Description at max length should work
+ role = RoleBase(name="Test", description="x" * 500)
+ assert len(role.description) == 500
+
+ # Description over max length should fail
+ with pytest.raises(ValidationError):
+ RoleBase(name="Test", description="x" * 501)
+
+ def test_role_create(self):
+ """RoleCreate accepts a name plus a permissions list and keeps the name."""
+ role = RoleCreate(
+ name="Editor",
+ permissions=["documents:create", "documents:read"],
+ )
+ assert role.name == "Editor"
+
+ def test_role_update_partial(self):
+ """RoleUpdate built with only a name leaves the other fields as None."""
+ update = RoleUpdate(name="NewName")
+ assert update.name == "NewName"
+ assert update.description is None
+ assert update.permissions is None
+ assert update.is_default is None
+
+ def test_role_update_full(self):
+ """RoleUpdate accepts all fields and keeps the supplied permissions list."""
+ update = RoleUpdate(
+ name="UpdatedRole",
+ description="Updated description",
+ permissions=["chats:read"],
+ is_default=True,
+ )
+ assert update.permissions == ["chats:read"]
+
+ def test_role_read(self):
+ """RoleRead keeps id, search_space_id and is_system_role as supplied."""
+ now = datetime.now(timezone.utc)
+ role = RoleRead(
+ id=1,
+ name="Viewer",
+ description="View-only access",
+ permissions=["documents:read"],
+ is_default=False,
+ search_space_id=5,
+ is_system_role=True,
+ created_at=now,
+ )
+ assert role.id == 1
+ assert role.is_system_role is True
+ assert role.search_space_id == 5
+
+
+class TestMembershipSchemas:
+ """Tests for MembershipBase, MembershipUpdate, MembershipRead and MembershipReadWithUser."""
+
+ def test_membership_base(self):
+ """MembershipBase can be built with no arguments."""
+ membership = MembershipBase()
+ assert membership is not None
+
+ def test_membership_update(self):
+ """MembershipUpdate stores the supplied role_id."""
+ update = MembershipUpdate(role_id=5)
+ assert update.role_id == 5
+
+ def test_membership_update_optional(self):
+ """MembershipUpdate built with no arguments leaves role_id as None."""
+ update = MembershipUpdate()
+ assert update.role_id is None
+
+ def test_membership_read(self):
+ """MembershipRead keeps user_id, search_space_id and is_owner; role may be None."""
+ now = datetime.now(timezone.utc)
+ user_id = uuid4()
+ membership = MembershipRead(
+ id=1,
+ user_id=user_id,
+ search_space_id=10,
+ role_id=2,
+ is_owner=False,
+ joined_at=now,
+ created_at=now,
+ role=None,
+ )
+ assert membership.user_id == user_id
+ assert membership.search_space_id == 10
+ assert membership.is_owner is False
+
+ def test_membership_read_with_role(self):
+ """MembershipRead exposes a nested RoleRead through its role attribute."""
+ now = datetime.now(timezone.utc)
+ user_id = uuid4()
+ role = RoleRead(
+ id=2,
+ name="Editor",
+ permissions=["documents:create"],
+ is_default=True,
+ search_space_id=10,
+ is_system_role=True,
+ created_at=now,
+ )
+ membership = MembershipRead(
+ id=1,
+ user_id=user_id,
+ search_space_id=10,
+ role_id=2,
+ is_owner=False,
+ joined_at=now,
+ created_at=now,
+ role=role,
+ )
+ assert membership.role.name == "Editor"
+
+ def test_membership_read_with_user(self):
+ """MembershipReadWithUser keeps the extra user_email and user_is_active fields."""
+ now = datetime.now(timezone.utc)
+ user_id = uuid4()
+ membership = MembershipReadWithUser(
+ id=1,
+ user_id=user_id,
+ search_space_id=10,
+ role_id=2,
+ is_owner=True,
+ joined_at=now,
+ created_at=now,
+ user_email="test@example.com",
+ user_is_active=True,
+ )
+ assert membership.user_email == "test@example.com"
+ assert membership.user_is_active is True
+
+
+class TestInviteSchemas:
+ """Tests for the invite base/create/update/read and accept/info schemas."""
+
+ def test_invite_base_minimal(self):
+ """InviteBase built with no arguments leaves every checked field as None."""
+ invite = InviteBase()
+ assert invite.name is None
+ assert invite.role_id is None
+ assert invite.expires_at is None
+ assert invite.max_uses is None
+
+ def test_invite_base_full(self):
+ """InviteBase keeps name and max_uses when all fields are supplied."""
+ expires = datetime.now(timezone.utc)
+ invite = InviteBase(
+ name="Team Invite",
+ role_id=3,
+ expires_at=expires,
+ max_uses=10,
+ )
+ assert invite.name == "Team Invite"
+ assert invite.max_uses == 10
+
+ def test_invite_base_max_uses_validation(self):
+ """InviteBase rejects max_uses=0 and accepts max_uses=1."""
+ with pytest.raises(ValidationError):
+ InviteBase(max_uses=0)
+
+ # Valid minimum
+ invite = InviteBase(max_uses=1)
+ assert invite.max_uses == 1
+
+ def test_invite_create(self):
+ """InviteCreate accepts name, role_id and max_uses and keeps the name."""
+ invite = InviteCreate(
+ name="Dev Team",
+ role_id=2,
+ max_uses=5,
+ )
+ assert invite.name == "Dev Team"
+
+ def test_invite_update_partial(self):
+ """InviteUpdate built with only is_active leaves name as None."""
+ update = InviteUpdate(is_active=False)
+ assert update.is_active is False
+ assert update.name is None
+
+ def test_invite_update_full(self):
+ """InviteUpdate accepts all fields and keeps the supplied name."""
+ expires = datetime.now(timezone.utc)
+ update = InviteUpdate(
+ name="Updated Invite",
+ role_id=4,
+ expires_at=expires,
+ max_uses=20,
+ is_active=True,
+ )
+ assert update.name == "Updated Invite"
+
+ def test_invite_read(self):
+ """InviteRead keeps invite_code and uses_count as supplied."""
+ now = datetime.now(timezone.utc)
+ user_id = uuid4()
+ invite = InviteRead(
+ id=1,
+ invite_code="abc123xyz",
+ search_space_id=5,
+ created_by_id=user_id,
+ uses_count=3,
+ is_active=True,
+ created_at=now,
+ )
+ assert invite.invite_code == "abc123xyz"
+ assert invite.uses_count == 3
+
+ def test_invite_accept_request(self):
+ """InviteAcceptRequest stores the supplied invite_code."""
+ request = InviteAcceptRequest(invite_code="valid-code-123")
+ assert request.invite_code == "valid-code-123"
+
+ def test_invite_accept_request_validation(self):
+ """InviteAcceptRequest rejects an empty invite_code with ValidationError."""
+ with pytest.raises(ValidationError):
+ InviteAcceptRequest(invite_code="")
+
+ def test_invite_accept_response(self):
+ """InviteAcceptResponse keeps message and search_space_name."""
+ response = InviteAcceptResponse(
+ message="Successfully joined",
+ search_space_id=10,
+ search_space_name="My Workspace",
+ role_name="Editor",
+ )
+ assert response.message == "Successfully joined"
+ assert response.search_space_name == "My Workspace"
+
+ def test_invite_info_response(self):
+ """InviteInfoResponse for a valid invite accepts message=None."""
+ response = InviteInfoResponse(
+ search_space_name="Public Space",
+ role_name="Viewer",
+ is_valid=True,
+ message=None,
+ )
+ assert response.is_valid is True
+
+ def test_invite_info_response_invalid(self):
+ """InviteInfoResponse for an invalid invite carries is_valid False and a message."""
+ response = InviteInfoResponse(
+ search_space_name="",
+ role_name=None,
+ is_valid=False,
+ message="Invite has expired",
+ )
+ assert response.is_valid is False
+ assert response.message == "Invite has expired"
+
+
+class TestPermissionSchemas:
+ """Tests for PermissionInfo and PermissionsListResponse."""
+
+ def test_permission_info(self):
+ """PermissionInfo keeps value and category as supplied."""
+ perm = PermissionInfo(
+ value="documents:create",
+ name="Create Documents",
+ category="Documents",
+ )
+ assert perm.value == "documents:create"
+ assert perm.category == "Documents"
+
+ def test_permissions_list_response(self):
+ """PermissionsListResponse holds both supplied PermissionInfo items."""
+ perms = [
+ PermissionInfo(value="documents:read", name="Read Documents", category="Documents"),
+ PermissionInfo(value="chats:read", name="Read Chats", category="Chats"),
+ ]
+ response = PermissionsListResponse(permissions=perms)
+ assert len(response.permissions) == 2
+
+ def test_permissions_list_response_empty(self):
+ """PermissionsListResponse accepts an empty permissions list."""
+ response = PermissionsListResponse(permissions=[])
+ assert response.permissions == []
+
+
+class TestUserAccessSchemas:
+ """Tests for the UserSearchSpaceAccess schema."""
+
+ def test_user_search_space_access(self):
+ """UserSearchSpaceAccess for an owner whose permissions list contains '*'."""
+ access = UserSearchSpaceAccess(
+ search_space_id=5,
+ search_space_name="My Workspace",
+ is_owner=True,
+ role_name="Owner",
+ permissions=["*"],
+ )
+ assert access.search_space_id == 5
+ assert access.is_owner is True
+ assert "*" in access.permissions
+
+ def test_user_search_space_access_member(self):
+ """UserSearchSpaceAccess for a non-owner with a role name and three permissions."""
+ access = UserSearchSpaceAccess(
+ search_space_id=10,
+ search_space_name="Team Space",
+ is_owner=False,
+ role_name="Editor",
+ permissions=["documents:create", "documents:read", "chats:create"],
+ )
+ assert access.is_owner is False
+ assert access.role_name == "Editor"
+ assert len(access.permissions) == 3
+
+ def test_user_search_space_access_no_role(self):
+ """UserSearchSpaceAccess accepts role_name=None with an empty permissions list."""
+ access = UserSearchSpaceAccess(
+ search_space_id=15,
+ search_space_name="Guest Space",
+ is_owner=False,
+ role_name=None,
+ permissions=[],
+ )
+ assert access.role_name is None
+ assert access.permissions == []
diff --git a/surfsense_backend/tests/test_rbac_utils.py b/surfsense_backend/tests/test_rbac_utils.py
new file mode 100644
index 000000000..193baada6
--- /dev/null
+++ b/surfsense_backend/tests/test_rbac_utils.py
@@ -0,0 +1,340 @@
+"""
+Tests for RBAC utility functions.
+
+This module tests the RBAC helper functions used for access control.
+"""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+from uuid import uuid4
+
+import pytest
+from fastapi import HTTPException
+
+from app.db import Permission
+from app.utils.rbac import (
+ check_permission,
+ check_search_space_access,
+ generate_invite_code,
+ get_user_membership,
+ get_user_permissions,
+ is_search_space_owner,
+)
+
+
+class TestGenerateInviteCode:
+ """Tests for the generate_invite_code function."""
+
+ def test_generates_string(self):
+ """generate_invite_code returns a str."""
+ code = generate_invite_code()
+ assert isinstance(code, str)
+
+ def test_generates_unique_codes(self):
+ """100 consecutive calls yield 100 distinct codes."""
+ codes = {generate_invite_code() for _ in range(100)}
+ assert len(codes) == 100 # All unique
+
+ def test_code_is_url_safe(self):
+ """Every character of the code is an ASCII letter, a digit, '-' or '_'."""
+ code = generate_invite_code()
+ # URL-safe characters: alphanumeric, hyphen, underscore
+ valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
+ assert all(c in valid_chars for c in code)
+
+ def test_code_length(self):
+ """The generated code is exactly 32 characters long."""
+ code = generate_invite_code()
+ # 24 bytes base64url-encode to exactly 32 chars (assumes token_urlsafe(24) - TODO confirm)
+ assert len(code) == 32
+
+
+class TestGetUserMembership:
+ """Tests for get_user_membership against a mocked session."""
+
+ @pytest.mark.asyncio
+ async def test_returns_membership(self):
+ """Returns the membership object that the mocked query result yields."""
+ mock_membership = MagicMock()
+ mock_membership.is_owner = True
+
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = mock_membership
+
+ mock_session = AsyncMock()
+ mock_session.execute.return_value = mock_result
+
+ user_id = uuid4()
+ result = await get_user_membership(mock_session, user_id, 1)
+
+ assert result == mock_membership
+ assert result.is_owner is True
+
+ @pytest.mark.asyncio
+ async def test_returns_none_when_not_found(self):
+ """Returns None when the mocked query result yields no row."""
+ mock_result = MagicMock()
+ mock_result.scalars.return_value.first.return_value = None
+
+ mock_session = AsyncMock()
+ mock_session.execute.return_value = mock_result
+
+ user_id = uuid4()
+ result = await get_user_membership(mock_session, user_id, 999)
+
+ assert result is None
+
+
+class TestGetUserPermissions:
+ """Tests for get_user_permissions with get_user_membership patched out."""
+
+ @pytest.mark.asyncio
+ async def test_owner_has_full_access(self):
+ """An owner membership yields permissions containing Permission.FULL_ACCESS."""
+ mock_membership = MagicMock()
+ mock_membership.is_owner = True
+ mock_membership.role = None
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ user_id = uuid4()
+
+ permissions = await get_user_permissions(mock_session, user_id, 1)
+
+ assert Permission.FULL_ACCESS.value in permissions
+
+ @pytest.mark.asyncio
+ async def test_member_gets_role_permissions(self):
+ """A non-owner membership yields exactly the permissions of its role."""
+ mock_role = MagicMock()
+ mock_role.permissions = ["documents:read", "chats:create"]
+
+ mock_membership = MagicMock()
+ mock_membership.is_owner = False
+ mock_membership.role = mock_role
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ user_id = uuid4()
+
+ permissions = await get_user_permissions(mock_session, user_id, 1)
+
+ assert permissions == ["documents:read", "chats:create"]
+
+ @pytest.mark.asyncio
+ async def test_no_membership_returns_empty(self):
+ """A missing membership yields an empty permissions list."""
+ with patch("app.utils.rbac.get_user_membership", return_value=None):
+ mock_session = AsyncMock()
+ user_id = uuid4()
+
+ permissions = await get_user_permissions(mock_session, user_id, 1)
+
+ assert permissions == []
+
+ @pytest.mark.asyncio
+ async def test_no_role_returns_empty(self):
+ """A non-owner membership without a role yields an empty permissions list."""
+ mock_membership = MagicMock()
+ mock_membership.is_owner = False
+ mock_membership.role = None
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ user_id = uuid4()
+
+ permissions = await get_user_permissions(mock_session, user_id, 1)
+
+ assert permissions == []
+
+
+class TestCheckPermission:
+ """Tests for check_permission with get_user_membership patched out."""
+
+ @pytest.mark.asyncio
+ async def test_owner_passes_any_permission(self):
+ """An owner with no role still passes the check and gets the membership back."""
+ mock_membership = MagicMock()
+ mock_membership.is_owner = True
+ mock_membership.role = None
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ mock_user = MagicMock()
+ mock_user.id = uuid4()
+
+ result = await check_permission(
+ mock_session,
+ mock_user,
+ 1,
+ Permission.SETTINGS_DELETE.value,
+ )
+
+ assert result == mock_membership
+
+ @pytest.mark.asyncio
+ async def test_member_with_permission_passes(self):
+ """A member whose role lists the requested permission gets the membership back."""
+ mock_role = MagicMock()
+ mock_role.permissions = [Permission.DOCUMENTS_READ.value, Permission.CHATS_READ.value]
+
+ mock_membership = MagicMock()
+ mock_membership.is_owner = False
+ mock_membership.role = mock_role
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ mock_user = MagicMock()
+ mock_user.id = uuid4()
+
+ result = await check_permission(
+ mock_session,
+ mock_user,
+ 1,
+ Permission.DOCUMENTS_READ.value,
+ )
+
+ assert result == mock_membership
+
+ @pytest.mark.asyncio
+ async def test_member_without_permission_raises(self):
+ """A member whose role lacks the requested permission gets HTTPException 403."""
+ mock_role = MagicMock()
+ mock_role.permissions = [Permission.DOCUMENTS_READ.value]
+
+ mock_membership = MagicMock()
+ mock_membership.is_owner = False
+ mock_membership.role = mock_role
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ mock_user = MagicMock()
+ mock_user.id = uuid4()
+
+ with pytest.raises(HTTPException) as exc_info:
+ await check_permission(
+ mock_session,
+ mock_user,
+ 1,
+ Permission.DOCUMENTS_DELETE.value,
+ )
+
+ assert exc_info.value.status_code == 403
+
+ @pytest.mark.asyncio
+ async def test_no_membership_raises(self):
+ """A user with no membership gets HTTPException 403 whose detail mentions access."""
+ with patch("app.utils.rbac.get_user_membership", return_value=None):
+ mock_session = AsyncMock()
+ mock_user = MagicMock()
+ mock_user.id = uuid4()
+
+ with pytest.raises(HTTPException) as exc_info:
+ await check_permission(
+ mock_session,
+ mock_user,
+ 1,
+ Permission.DOCUMENTS_READ.value,
+ )
+
+ assert exc_info.value.status_code == 403
+ assert "access to this search space" in exc_info.value.detail
+
+ @pytest.mark.asyncio
+ async def test_custom_error_message(self):
+ """A supplied error_message becomes the detail of the raised HTTPException."""
+ mock_role = MagicMock()
+ mock_role.permissions = []
+
+ mock_membership = MagicMock()
+ mock_membership.is_owner = False
+ mock_membership.role = mock_role
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ mock_user = MagicMock()
+ mock_user.id = uuid4()
+
+ with pytest.raises(HTTPException) as exc_info:
+ await check_permission(
+ mock_session,
+ mock_user,
+ 1,
+ Permission.DOCUMENTS_DELETE.value,
+ error_message="Custom error message",
+ )
+
+ assert exc_info.value.detail == "Custom error message"
+
+
+class TestCheckSearchSpaceAccess:
+ """Tests for check_search_space_access with get_user_membership patched out."""
+
+ @pytest.mark.asyncio
+ async def test_member_has_access(self):
+ """Any existing membership is returned unchanged."""
+ mock_membership = MagicMock()
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ mock_user = MagicMock()
+ mock_user.id = uuid4()
+
+ result = await check_search_space_access(mock_session, mock_user, 1)
+
+ assert result == mock_membership
+
+ @pytest.mark.asyncio
+ async def test_no_membership_raises(self):
+ """A user with no membership gets HTTPException 403."""
+ with patch("app.utils.rbac.get_user_membership", return_value=None):
+ mock_session = AsyncMock()
+ mock_user = MagicMock()
+ mock_user.id = uuid4()
+
+ with pytest.raises(HTTPException) as exc_info:
+ await check_search_space_access(mock_session, mock_user, 1)
+
+ assert exc_info.value.status_code == 403
+
+
+class TestIsSearchSpaceOwner:
+ """Tests for is_search_space_owner with get_user_membership patched out."""
+
+ @pytest.mark.asyncio
+ async def test_returns_true_for_owner(self):
+ """Returns True when the membership has is_owner set to True."""
+ mock_membership = MagicMock()
+ mock_membership.is_owner = True
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ user_id = uuid4()
+
+ result = await is_search_space_owner(mock_session, user_id, 1)
+
+ assert result is True
+
+ @pytest.mark.asyncio
+ async def test_returns_false_for_non_owner(self):
+ """Returns False when the membership has is_owner set to False."""
+ mock_membership = MagicMock()
+ mock_membership.is_owner = False
+
+ with patch("app.utils.rbac.get_user_membership", return_value=mock_membership):
+ mock_session = AsyncMock()
+ user_id = uuid4()
+
+ result = await is_search_space_owner(mock_session, user_id, 1)
+
+ assert result is False
+
+ @pytest.mark.asyncio
+ async def test_returns_false_for_no_membership(self):
+ """Returns False when there is no membership at all."""
+ with patch("app.utils.rbac.get_user_membership", return_value=None):
+ mock_session = AsyncMock()
+ user_id = uuid4()
+
+ result = await is_search_space_owner(mock_session, user_id, 1)
+
+ assert result is False
diff --git a/surfsense_backend/tests/test_schemas.py b/surfsense_backend/tests/test_schemas.py
new file mode 100644
index 000000000..1aabc63bd
--- /dev/null
+++ b/surfsense_backend/tests/test_schemas.py
@@ -0,0 +1,569 @@
+"""
+Tests for Pydantic schema models.
+
+This module tests schema validation, serialization, and deserialization
+for all schema models used in the application.
+"""
+
+from datetime import datetime, timezone
+from uuid import uuid4
+
+import pytest
+from pydantic import ValidationError
+
+from app.db import ChatType, DocumentType, LiteLLMProvider
+from app.schemas.base import IDModel, TimestampModel
+from app.schemas.chats import (
+ AISDKChatRequest,
+ ChatBase,
+ ChatCreate,
+ ChatRead,
+ ChatReadWithoutMessages,
+ ChatUpdate,
+ ClientAttachment,
+ ToolInvocation,
+)
+from app.schemas.chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
+from app.schemas.documents import (
+ DocumentBase,
+ DocumentRead,
+ DocumentsCreate,
+ DocumentUpdate,
+ DocumentWithChunksRead,
+ ExtensionDocumentContent,
+ ExtensionDocumentMetadata,
+ PaginatedResponse,
+)
+from app.schemas.llm_config import (
+ LLMConfigBase,
+ LLMConfigCreate,
+ LLMConfigRead,
+ LLMConfigUpdate,
+)
+from app.schemas.search_space import (
+ SearchSpaceBase,
+ SearchSpaceCreate,
+ SearchSpaceRead,
+ SearchSpaceUpdate,
+ SearchSpaceWithStats,
+)
+
+
+class TestBaseSchemas:
+ """Tests for the TimestampModel and IDModel base schemas."""
+
+ def test_timestamp_model(self):
+ """TimestampModel keeps the supplied timezone-aware created_at."""
+ now = datetime.now(timezone.utc)
+ model = TimestampModel(created_at=now)
+ assert model.created_at == now
+
+ def test_id_model(self):
+ """IDModel keeps the supplied id."""
+ model = IDModel(id=1)
+ assert model.id == 1
+
+ def test_id_model_with_zero(self):
+ """IDModel accepts id=0."""
+ model = IDModel(id=0)
+ assert model.id == 0
+
+
+class TestChatSchemas:
+ """Tests for the chat schemas plus ClientAttachment, ToolInvocation and AISDKChatRequest."""
+
+ def test_chat_base_valid(self):
+ """ChatBase keeps the supplied fields and defaults state_version to 1."""
+ chat = ChatBase(
+ type=ChatType.QNA,
+ title="Test Chat",
+ messages=[{"role": "user", "content": "Hello"}],
+ search_space_id=1,
+ )
+ assert chat.type == ChatType.QNA
+ assert chat.title == "Test Chat"
+ assert chat.search_space_id == 1
+ assert chat.state_version == 1
+
+ def test_chat_base_with_connectors(self):
+ """ChatBase keeps the supplied initial_connectors list."""
+ chat = ChatBase(
+ type=ChatType.QNA,
+ title="Test Chat",
+ initial_connectors=["slack", "notion"],
+ messages=[],
+ search_space_id=1,
+ )
+ assert chat.initial_connectors == ["slack", "notion"]
+
+ def test_chat_base_default_state_version(self):
+ """ChatBase defaults state_version to 1 when it is omitted."""
+ chat = ChatBase(
+ type=ChatType.QNA,
+ title="Test Chat",
+ messages=[],
+ search_space_id=1,
+ )
+ assert chat.state_version == 1
+
+ def test_chat_create(self):
+ """ChatCreate keeps the supplied title."""
+ chat = ChatCreate(
+ type=ChatType.QNA,
+ title="New Chat",
+ messages=[{"role": "user", "content": "Test"}],
+ search_space_id=1,
+ )
+ assert chat.title == "New Chat"
+
+ def test_chat_update(self):
+ """ChatUpdate keeps an explicitly supplied state_version."""
+ chat = ChatUpdate(
+ type=ChatType.QNA,
+ title="Updated Chat",
+ messages=[{"role": "user", "content": "Updated"}],
+ search_space_id=1,
+ state_version=2,
+ )
+ assert chat.state_version == 2
+
+ def test_chat_read(self):
+ """ChatRead keeps id and created_at."""
+ now = datetime.now(timezone.utc)
+ chat = ChatRead(
+ id=1,
+ type=ChatType.QNA,
+ title="Read Chat",
+ messages=[],
+ search_space_id=1,
+ created_at=now,
+ )
+ assert chat.id == 1
+ assert chat.created_at == now
+
+ def test_chat_read_without_messages(self):
+ """ChatReadWithoutMessages builds without messages and exposes no such field."""
+ now = datetime.now(timezone.utc)
+ chat = ChatReadWithoutMessages(
+ id=1,
+ type=ChatType.QNA,
+ title="Chat Without Messages",
+ search_space_id=1,
+ created_at=now,
+ )
+ assert chat.id == 1
+ assert not hasattr(chat, "messages") or "messages" not in chat.model_fields
+
+ def test_client_attachment(self):
+ """ClientAttachment keeps name and content_type."""
+ attachment = ClientAttachment(
+ name="test.pdf",
+ content_type="application/pdf",
+ url="https://example.com/test.pdf",
+ )
+ assert attachment.name == "test.pdf"
+ assert attachment.content_type == "application/pdf"
+
+ def test_tool_invocation(self):
+ """ToolInvocation keeps tool_call_id and tool_name."""
+ tool = ToolInvocation(
+ tool_call_id="tc_123",
+ tool_name="search",
+ args={"query": "test"},
+ result={"results": []},
+ )
+ assert tool.tool_call_id == "tc_123"
+ assert tool.tool_name == "search"
+
+ def test_aisdk_chat_request(self):
+ """AISDKChatRequest keeps the messages list and the supplied data dict."""
+ request = AISDKChatRequest(
+ messages=[{"role": "user", "content": "Hello"}],
+ data={"search_space_id": 1},
+ )
+ assert len(request.messages) == 1
+ assert request.data["search_space_id"] == 1
+
+ def test_aisdk_chat_request_no_data(self):
+ """AISDKChatRequest defaults data to None when it is omitted."""
+ request = AISDKChatRequest(messages=[{"role": "user", "content": "Hello"}])
+ assert request.data is None
+
+
+class TestChunkSchemas:
+ """Tests for ChunkBase, ChunkCreate, ChunkUpdate and ChunkRead."""
+
+ def test_chunk_base(self):
+ """ChunkBase keeps content and document_id."""
+ chunk = ChunkBase(content="Test content", document_id=1)
+ assert chunk.content == "Test content"
+ assert chunk.document_id == 1
+
+ def test_chunk_create(self):
+ """ChunkCreate keeps the supplied content."""
+ chunk = ChunkCreate(content="New chunk content", document_id=1)
+ assert chunk.content == "New chunk content"
+
+ def test_chunk_update(self):
+ """ChunkUpdate keeps the supplied content."""
+ chunk = ChunkUpdate(content="Updated content", document_id=1)
+ assert chunk.content == "Updated content"
+
+ def test_chunk_read(self):
+ """ChunkRead keeps id and created_at."""
+ now = datetime.now(timezone.utc)
+ chunk = ChunkRead(
+ id=1,
+ content="Read chunk",
+ document_id=1,
+ created_at=now,
+ )
+ assert chunk.id == 1
+ assert chunk.created_at == now
+
+
+class TestDocumentSchemas:
+ """Tests for the document, extension-content and pagination schemas."""
+
+ def test_extension_document_metadata(self):
+ """ExtensionDocumentMetadata keeps BrowsingSessionId and VisitedWebPageURL."""
+ metadata = ExtensionDocumentMetadata(
+ BrowsingSessionId="session123",
+ VisitedWebPageURL="https://example.com",
+ VisitedWebPageTitle="Example Page",
+ VisitedWebPageDateWithTimeInISOString="2024-01-01T00:00:00Z",
+ VisitedWebPageReffererURL="https://google.com",
+ VisitedWebPageVisitDurationInMilliseconds="5000",
+ )
+ assert metadata.BrowsingSessionId == "session123"
+ assert metadata.VisitedWebPageURL == "https://example.com"
+
+ def test_extension_document_content(self):
+ """ExtensionDocumentContent keeps pageContent and nests the supplied metadata."""
+ metadata = ExtensionDocumentMetadata(
+ BrowsingSessionId="session123",
+ VisitedWebPageURL="https://example.com",
+ VisitedWebPageTitle="Example Page",
+ VisitedWebPageDateWithTimeInISOString="2024-01-01T00:00:00Z",
+ VisitedWebPageReffererURL="https://google.com",
+ VisitedWebPageVisitDurationInMilliseconds="5000",
+ )
+ content = ExtensionDocumentContent(
+ metadata=metadata,
+ pageContent="This is the page content",
+ )
+ assert content.pageContent == "This is the page content"
+ assert content.metadata.VisitedWebPageTitle == "Example Page"
+
+ def test_document_base_with_string_content(self):
+ """DocumentBase accepts a plain string as content."""
+ doc = DocumentBase(
+ document_type=DocumentType.FILE,
+ content="This is document content",
+ search_space_id=1,
+ )
+ assert doc.content == "This is document content"
+
+ def test_document_base_with_list_content(self):
+ """DocumentBase accepts a list of strings as content."""
+ doc = DocumentBase(
+ document_type=DocumentType.FILE,
+ content=["Part 1", "Part 2"],
+ search_space_id=1,
+ )
+ assert len(doc.content) == 2
+
+ def test_documents_create(self):
+ """DocumentsCreate keeps the supplied document_type."""
+ doc = DocumentsCreate(
+ document_type=DocumentType.CRAWLED_URL,
+ content="Crawled content",
+ search_space_id=1,
+ )
+ assert doc.document_type == DocumentType.CRAWLED_URL
+
+ def test_document_update(self):
+ """DocumentUpdate keeps the supplied content."""
+ doc = DocumentUpdate(
+ document_type=DocumentType.FILE,
+ content="Updated content",
+ search_space_id=1,
+ )
+ assert doc.content == "Updated content"
+
+ def test_document_read(self):
+ """DocumentRead keeps id, title and the document_metadata dict."""
+ now = datetime.now(timezone.utc)
+ doc = DocumentRead(
+ id=1,
+ title="Test Document",
+ document_type=DocumentType.FILE,
+ document_metadata={"key": "value"},
+ content="Content",
+ created_at=now,
+ search_space_id=1,
+ )
+ assert doc.id == 1
+ assert doc.title == "Test Document"
+ assert doc.document_metadata["key"] == "value"
+
+ def test_document_with_chunks_read(self):
+ """DocumentWithChunksRead holds both supplied ChunkRead items."""
+ now = datetime.now(timezone.utc)
+ doc = DocumentWithChunksRead(
+ id=1,
+ title="Test Document",
+ document_type=DocumentType.FILE,
+ document_metadata={},
+ content="Content",
+ created_at=now,
+ search_space_id=1,
+ chunks=[
+ ChunkRead(id=1, content="Chunk 1", document_id=1, created_at=now),
+ ChunkRead(id=2, content="Chunk 2", document_id=1, created_at=now),
+ ],
+ )
+ assert len(doc.chunks) == 2
+
+ def test_paginated_response(self):
+ """PaginatedResponse[dict] keeps the items list and the total count."""
+ response = PaginatedResponse[dict](
+ items=[{"id": 1}, {"id": 2}],
+ total=10,
+ )
+ assert len(response.items) == 2
+ assert response.total == 10
+
+
+class TestLLMConfigSchemas:
+ """Tests for LLMConfigBase, LLMConfigCreate, LLMConfigUpdate and LLMConfigRead."""
+
+ def test_llm_config_base(self):
+ """LLMConfigBase keeps name and provider and defaults language to 'English'."""
+ config = LLMConfigBase(
+ name="GPT-4 Config",
+ provider=LiteLLMProvider.OPENAI,
+ model_name="gpt-4",
+ api_key="sk-test123",
+ )
+ assert config.name == "GPT-4 Config"
+ assert config.provider == LiteLLMProvider.OPENAI
+ assert config.language == "English" # Default value
+
+ def test_llm_config_base_with_custom_provider(self):
+ """LLMConfigBase keeps custom_provider and api_base for a CUSTOM provider."""
+ config = LLMConfigBase(
+ name="Custom LLM",
+ provider=LiteLLMProvider.CUSTOM,
+ custom_provider="my-provider",
+ model_name="my-model",
+ api_key="test-key",
+ api_base="https://my-api.com/v1",
+ )
+ assert config.custom_provider == "my-provider"
+ assert config.api_base == "https://my-api.com/v1"
+
+ def test_llm_config_base_with_litellm_params(self):
+ """LLMConfigBase keeps the supplied litellm_params dict."""
+ config = LLMConfigBase(
+ name="Config with Params",
+ provider=LiteLLMProvider.ANTHROPIC,
+ model_name="claude-3-opus",
+ api_key="test-key",
+ litellm_params={"temperature": 0.7, "max_tokens": 1000},
+ )
+ assert config.litellm_params["temperature"] == 0.7
+
+ def test_llm_config_create(self):
+ """LLMConfigCreate keeps the supplied search_space_id."""
+ config = LLMConfigCreate(
+ name="New Config",
+ provider=LiteLLMProvider.GROQ,
+ model_name="llama-3",
+ api_key="gsk-test",
+ search_space_id=1,
+ )
+ assert config.search_space_id == 1
+
+ def test_llm_config_update_partial(self):
+ """LLMConfigUpdate built with only a name leaves provider and model_name as None."""
+ update = LLMConfigUpdate(name="Updated Name")
+ assert update.name == "Updated Name"
+ assert update.provider is None
+ assert update.model_name is None
+
+ def test_llm_config_update_full(self):
+ """LLMConfigUpdate keeps a supplied language alongside the other fields."""
+ update = LLMConfigUpdate(
+ name="Full Update",
+ provider=LiteLLMProvider.MISTRAL,
+ model_name="mistral-large",
+ api_key="new-key",
+ language="French",
+ )
+ assert update.language == "French"
+
+ def test_llm_config_read(self):
+ """LLMConfigRead keeps id and created_at."""
+ now = datetime.now(timezone.utc)
+ config = LLMConfigRead(
+ id=1,
+ name="Read Config",
+ provider=LiteLLMProvider.OPENAI,
+ model_name="gpt-4",
+ api_key="sk-test",
+ created_at=now,
+ search_space_id=1,
+ )
+ assert config.id == 1
+ assert config.created_at == now
+
+ def test_llm_config_read_global(self):
+ """LLMConfigRead accepts id=-1 with created_at and search_space_id both None."""
+ config = LLMConfigRead(
+ id=-1,
+ name="Global Config",
+ provider=LiteLLMProvider.OPENAI,
+ model_name="gpt-4",
+ api_key="sk-global",
+ created_at=None,
+ search_space_id=None,
+ )
+ assert config.id == -1
+ assert config.search_space_id is None
+
+
+class TestSearchSpaceSchemas:
+    """Construction checks for the SearchSpace* schema family."""
+
+    def test_search_space_base(self):
+        """SearchSpaceBase needs only a name; description falls back to None."""
+        base = SearchSpaceBase(name="My Search Space")
+        assert base.description is None
+        assert base.name == "My Search Space"
+
+    def test_search_space_base_with_description(self):
+        """SearchSpaceBase keeps an explicitly supplied description."""
+        described = SearchSpaceBase(
+            description="A space for searching",
+            name="My Search Space",
+        )
+        assert described.description == "A space for searching"
+
+    def test_search_space_create_defaults(self):
+        """SearchSpaceCreate defaults: citations on, no custom QnA instructions."""
+        created = SearchSpaceCreate(name="New Space")
+        assert created.qna_custom_instructions is None
+        assert created.citations_enabled is True
+
+    def test_search_space_create_custom(self):
+        """SearchSpaceCreate honours explicit citation and instruction settings."""
+        created = SearchSpaceCreate(
+            name="Custom Space",
+            description="Custom description",
+            qna_custom_instructions="Be concise",
+            citations_enabled=False,
+        )
+        assert created.qna_custom_instructions == "Be concise"
+        assert created.citations_enabled is False
+
+    def test_search_space_update_partial(self):
+        """SearchSpaceUpdate given only a name leaves the checked fields None."""
+        changes = SearchSpaceUpdate(name="Updated Name")
+        assert changes.citations_enabled is None
+        assert changes.description is None
+        assert changes.name == "Updated Name"
+
+    def test_search_space_update_full(self):
+        """SearchSpaceUpdate accepts four fields at once and keeps the instructions."""
+        changes = SearchSpaceUpdate(
+            name="Full Update",
+            description="New description",
+            qna_custom_instructions="New instructions",
+            citations_enabled=True,
+        )
+        assert changes.qna_custom_instructions == "New instructions"
+
+    def test_search_space_read(self):
+        """SearchSpaceRead returns the id and user UUID it was built with."""
+        user_uuid = uuid4()
+        created = datetime.now(timezone.utc)
+        stored = SearchSpaceRead(
+            id=1,
+            name="Read Space",
+            description="Description",
+            created_at=created,
+            user_id=user_uuid,
+            citations_enabled=True,
+            qna_custom_instructions=None,
+        )
+        assert stored.user_id == user_uuid
+        assert stored.id == 1
+
+    def test_search_space_with_stats(self):
+        """SearchSpaceWithStats carries explicit member_count and is_owner values."""
+        user_uuid = uuid4()
+        created = datetime.now(timezone.utc)
+        with_stats = SearchSpaceWithStats(
+            id=1,
+            name="Space with Stats",
+            created_at=created,
+            user_id=user_uuid,
+            citations_enabled=True,
+            is_owner=True,
+            member_count=5,
+        )
+        assert with_stats.is_owner is True
+        assert with_stats.member_count == 5
+
+    def test_search_space_with_stats_defaults(self):
+        """SearchSpaceWithStats defaults to member_count 1 and is_owner False."""
+        user_uuid = uuid4()
+        created = datetime.now(timezone.utc)
+        with_stats = SearchSpaceWithStats(
+            id=1,
+            name="Default Stats Space",
+            created_at=created,
+            user_id=user_uuid,
+            citations_enabled=True,
+        )
+        assert with_stats.is_owner is False
+        assert with_stats.member_count == 1
+
+
+class TestSchemaValidation:
+    """Checks that invalid schema input is rejected with ValidationError."""
+
+    def test_chat_base_missing_required(self):
+        """ChatBase without messages and search_space_id is rejected."""
+        with pytest.raises(ValidationError):
+            ChatBase(title="Test", type=ChatType.QNA)  # no messages / search_space_id
+
+    def test_llm_config_name_too_long(self):
+        """LLMConfigBase rejects a 101-character name."""
+        with pytest.raises(ValidationError):
+            LLMConfigBase(
+                provider=LiteLLMProvider.OPENAI,
+                model_name="gpt-4",
+                api_key="test",
+                name="x" * 101,  # one past the max_length of 100
+            )
+
+    def test_llm_config_model_name_too_long(self):
+        """LLMConfigBase rejects a 101-character model_name."""
+        with pytest.raises(ValidationError):
+            LLMConfigBase(
+                name="Valid Name",
+                provider=LiteLLMProvider.OPENAI,
+                api_key="test",
+                model_name="x" * 101,  # one past the max_length of 100
+            )
+
+    def test_document_read_missing_required(self):
+        """DocumentRead built from only id and title is rejected."""
+        with pytest.raises(ValidationError):
+            DocumentRead(
+                title="Test",
+                id=1,
+                # Omitted: document_type, document_metadata, content, created_at, search_space_id
+            )
diff --git a/surfsense_backend/tests/test_validators.py b/surfsense_backend/tests/test_validators.py
new file mode 100644
index 000000000..8baa690d4
--- /dev/null
+++ b/surfsense_backend/tests/test_validators.py
@@ -0,0 +1,441 @@
+"""
+Tests for the validators module.
+"""
+
+import pytest
+from fastapi import HTTPException
+
+from app.utils.validators import (
+ validate_connectors,
+ validate_document_ids,
+ validate_email,
+ validate_messages,
+ validate_research_mode,
+ validate_search_mode,
+ validate_search_space_id,
+ validate_top_k,
+ validate_url,
+ validate_uuid,
+)
+
+
+class TestValidateSearchSpaceId:
+    """Behaviour of validate_search_space_id for accepted and rejected input."""
+
+    def test_valid_integer(self):
+        """Positive integers come back unchanged."""
+        accepted = (1, 100, 999999)
+        for space_id in accepted:
+            assert validate_search_space_id(space_id) == space_id
+
+    def test_valid_string(self):
+        """Numeric strings are converted to int, ignoring surrounding spaces."""
+        cases = {"1": 1, "100": 100, " 50 ": 50}  # last entry is trimmed
+        for raw, expected in cases.items():
+            assert validate_search_space_id(raw) == expected
+
+    def test_none_raises_error(self):
+        """None is rejected with a 400 whose detail mentions 'required'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_search_space_id(None)
+        assert excinfo.value.status_code == 400
+        assert "required" in excinfo.value.detail
+
+    def test_zero_raises_error(self):
+        """Zero is rejected with a 400 whose detail mentions 'positive'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_search_space_id(0)
+        assert excinfo.value.status_code == 400
+        assert "positive" in excinfo.value.detail
+
+    def test_negative_raises_error(self):
+        """A negative id is rejected with a 400 whose detail mentions 'positive'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_search_space_id(-1)
+        assert excinfo.value.status_code == 400
+        assert "positive" in excinfo.value.detail
+
+    def test_boolean_raises_error(self):
+        """True is rejected with a 400 whose detail mentions 'boolean'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_search_space_id(True)
+        assert excinfo.value.status_code == 400
+        assert "boolean" in excinfo.value.detail
+
+    def test_empty_string_raises_error(self):
+        """An empty string is rejected with a 400."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_search_space_id("")
+        assert excinfo.value.status_code == 400
+
+    def test_invalid_string_raises_error(self):
+        """A non-numeric string is rejected with a 400."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_search_space_id("abc")
+        assert excinfo.value.status_code == 400
+
+    def test_float_raises_error(self):
+        """The float 1.5 is rejected with a 400."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_search_space_id(1.5)
+        assert excinfo.value.status_code == 400
+
+
+class TestValidateDocumentIds:
+    """Behaviour of validate_document_ids for accepted and rejected input."""
+
+    def test_none_returns_empty_list(self):
+        """None is normalised to an empty list."""
+        assert validate_document_ids(None) == []
+
+    def test_empty_list_returns_empty_list(self):
+        """An empty list is returned as an empty list."""
+        assert validate_document_ids([]) == []
+
+    def test_valid_integer_list(self):
+        """A list of positive integers is returned as given."""
+        assert validate_document_ids([1, 2, 3]) == [1, 2, 3]
+
+    def test_valid_string_list(self):
+        """Numeric strings are converted to integers."""
+        assert validate_document_ids(["1", "2", "3"]) == [1, 2, 3]
+
+    def test_mixed_valid_types(self):
+        """Integers and numeric strings may be mixed in one list."""
+        assert validate_document_ids([1, "2", 3]) == [1, 2, 3]
+
+    def test_not_list_raises_error(self):
+        """A plain string is rejected with a 400 mentioning 'must be a list'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_document_ids("not a list")
+        assert excinfo.value.status_code == 400
+        assert "must be a list" in excinfo.value.detail
+
+    def test_negative_id_raises_error(self):
+        """A negative entry is rejected with a 400 mentioning 'positive'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_document_ids([1, -2, 3])
+        assert excinfo.value.status_code == 400
+        assert "positive" in excinfo.value.detail
+
+    def test_zero_id_raises_error(self):
+        """A zero entry is rejected with a 400 mentioning 'positive'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_document_ids([0])
+        assert excinfo.value.status_code == 400
+        assert "positive" in excinfo.value.detail
+
+    def test_boolean_in_list_raises_error(self):
+        """A boolean entry is rejected with a 400 mentioning 'boolean'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_document_ids([1, True, 3])
+        assert excinfo.value.status_code == 400
+        assert "boolean" in excinfo.value.detail
+
+
+class TestValidateConnectors:
+    """Behaviour of validate_connectors for accepted and rejected input."""
+
+    def test_none_returns_empty_list(self):
+        """None is normalised to an empty list."""
+        assert validate_connectors(None) == []
+
+    def test_empty_list_returns_empty_list(self):
+        """An empty list is returned as an empty list."""
+        assert validate_connectors([]) == []
+
+    def test_valid_connectors(self):
+        """Plain lowercase names are returned as given."""
+        assert validate_connectors(["slack", "github"]) == ["slack", "github"]
+
+    def test_connector_with_underscore(self):
+        """Underscores are allowed in connector names."""
+        assert validate_connectors(["google_calendar"]) == ["google_calendar"]
+
+    def test_connector_with_hyphen(self):
+        """Hyphens are allowed in connector names."""
+        assert validate_connectors(["google-calendar"]) == ["google-calendar"]
+
+    def test_not_list_raises_error(self):
+        """A plain string is rejected with a 400 mentioning 'must be a list'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_connectors("not a list")
+        assert excinfo.value.status_code == 400
+        assert "must be a list" in excinfo.value.detail
+
+    def test_non_string_in_list_raises_error(self):
+        """An integer entry is rejected with a 400 mentioning 'must be a string'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_connectors(["slack", 123])
+        assert excinfo.value.status_code == 400
+        assert "must be a string" in excinfo.value.detail
+
+    def test_empty_string_raises_error(self):
+        """An empty name is rejected with a 400 mentioning 'cannot be empty'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_connectors(["slack", ""])
+        assert excinfo.value.status_code == 400
+        assert "cannot be empty" in excinfo.value.detail
+
+    def test_invalid_characters_raises_error(self):
+        """A name containing '@' is rejected with a 400 mentioning 'invalid characters'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_connectors(["slack@connector"])
+        assert excinfo.value.status_code == 400
+        assert "invalid characters" in excinfo.value.detail
+
+
+class TestValidateResearchMode:
+    """Behaviour of validate_research_mode for accepted and rejected input."""
+
+    def test_none_returns_default(self):
+        """None falls back to "QNA"."""
+        assert validate_research_mode(None) == "QNA"
+
+    def test_valid_mode(self):
+        """"QNA" is accepted in upper or lower case and returned upper-cased."""
+        for spelling in ("QNA", "qna"):  # matching is case-insensitive
+            assert validate_research_mode(spelling) == "QNA"
+
+    def test_non_string_raises_error(self):
+        """An integer is rejected with a 400 mentioning 'must be a string'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_research_mode(123)
+        assert excinfo.value.status_code == 400
+        assert "must be a string" in excinfo.value.detail
+
+    def test_invalid_mode_raises_error(self):
+        """An unknown mode is rejected with a 400 mentioning 'must be one of'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_research_mode("INVALID")
+        assert excinfo.value.status_code == 400
+        assert "must be one of" in excinfo.value.detail
+
+
+class TestValidateSearchMode:
+    """Behaviour of validate_search_mode for accepted and rejected input."""
+
+    def test_none_returns_default(self):
+        """None falls back to "CHUNKS"."""
+        assert validate_search_mode(None) == "CHUNKS"
+
+    def test_valid_modes(self):
+        """CHUNKS and DOCUMENTS are accepted; lower case is upper-cased."""
+        modes = ("CHUNKS", "DOCUMENTS", "chunks")  # last entry is case-insensitive
+        for mode in modes:
+            assert validate_search_mode(mode) == mode.upper()
+
+    def test_non_string_raises_error(self):
+        """An integer is rejected with a 400 mentioning 'must be a string'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_search_mode(123)
+        assert excinfo.value.status_code == 400
+        assert "must be a string" in excinfo.value.detail
+
+    def test_invalid_mode_raises_error(self):
+        """An unknown mode is rejected with a 400 mentioning 'must be one of'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_search_mode("INVALID")
+        assert excinfo.value.status_code == 400
+        assert "must be one of" in excinfo.value.detail
+
+
+class TestValidateTopK:
+    """Behaviour of validate_top_k for accepted and rejected input."""
+
+    def test_none_returns_default(self):
+        """None falls back to 10."""
+        assert validate_top_k(None) == 10
+
+    def test_valid_integer(self):
+        """Integers from 1 up to 100 come back unchanged."""
+        accepted = (1, 50, 100)
+        for top_k in accepted:
+            assert validate_top_k(top_k) == top_k
+
+    def test_valid_string(self):
+        """Numeric strings are converted to int, ignoring surrounding spaces."""
+        for raw, expected in (("5", 5), (" 10 ", 10)):
+            assert validate_top_k(raw) == expected
+
+    def test_zero_raises_error(self):
+        """Zero is rejected with a 400 whose detail mentions 'positive'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_top_k(0)
+        assert excinfo.value.status_code == 400
+        assert "positive" in excinfo.value.detail
+
+    def test_negative_raises_error(self):
+        """A negative value is rejected with a 400 mentioning 'positive'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_top_k(-1)
+        assert excinfo.value.status_code == 400
+        assert "positive" in excinfo.value.detail
+
+    def test_exceeds_max_raises_error(self):
+        """101 is rejected with a 400 whose detail mentions 'exceed 100'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_top_k(101)
+        assert excinfo.value.status_code == 400
+        assert "exceed 100" in excinfo.value.detail
+
+    def test_boolean_raises_error(self):
+        """True is rejected with a 400 whose detail mentions 'boolean'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_top_k(True)
+        assert excinfo.value.status_code == 400
+        assert "boolean" in excinfo.value.detail
+
+
+class TestValidateMessages:
+    """Behaviour of validate_messages for accepted and rejected input."""
+
+    def test_valid_messages(self):
+        """A user/assistant exchange is accepted and keeps its order."""
+        conversation = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        validated = validate_messages(conversation)
+        assert len(validated) == 2
+        assert validated[0]["role"] == "user"
+        assert validated[1]["role"] == "assistant"
+
+    def test_trims_content(self):
+        """Surrounding whitespace is stripped from message content."""
+        padded = [{"role": "user", "content": " Hello "}]
+        validated = validate_messages(padded)
+        assert validated[0]["content"] == "Hello"
+
+    def test_system_message_valid(self):
+        """A leading system message is accepted."""
+        conversation = [
+            {"role": "system", "content": "You are helpful"},
+            {"role": "user", "content": "Hello"},
+        ]
+        validated = validate_messages(conversation)
+        assert validated[0]["role"] == "system"
+
+    def test_not_list_raises_error(self):
+        """A plain string is rejected with a 400 mentioning 'must be a list'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_messages("not a list")
+        assert excinfo.value.status_code == 400
+        assert "must be a list" in excinfo.value.detail
+
+    def test_empty_list_raises_error(self):
+        """An empty list is rejected with a 400 mentioning 'cannot be empty'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_messages([])
+        assert excinfo.value.status_code == 400
+        assert "cannot be empty" in excinfo.value.detail
+
+    def test_missing_role_raises_error(self):
+        """A message without a role is rejected with a 400 mentioning 'role'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_messages([{"content": "Hello"}])
+        assert excinfo.value.status_code == 400
+        assert "role" in excinfo.value.detail
+
+    def test_missing_content_raises_error(self):
+        """A message without content is rejected with a 400 mentioning 'content'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_messages([{"role": "user"}])
+        assert excinfo.value.status_code == 400
+        assert "content" in excinfo.value.detail
+
+    def test_invalid_role_raises_error(self):
+        """An unknown role is rejected with a 400 mentioning 'role'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_messages([{"role": "invalid", "content": "Hello"}])
+        assert excinfo.value.status_code == 400
+        assert "role" in excinfo.value.detail
+
+    def test_empty_content_raises_error(self):
+        """Whitespace-only content is rejected with a 400 mentioning 'cannot be empty'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_messages([{"role": "user", "content": " "}])
+        assert excinfo.value.status_code == 400
+        assert "cannot be empty" in excinfo.value.detail
+
+
+class TestValidateEmail:
+    """Behaviour of validate_email for accepted and rejected input."""
+
+    def test_valid_email(self):
+        """Well-formed addresses are returned unchanged."""
+        for address in ("test@example.com", "user.name@domain.co.uk"):
+            assert validate_email(address) == address
+
+    def test_trims_whitespace(self):
+        """Surrounding spaces are stripped from an otherwise valid address."""
+        assert validate_email(" test@example.com ") == "test@example.com"
+
+    def test_empty_raises_error(self):
+        """An empty string is rejected with a 400."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_email("")
+        assert excinfo.value.status_code == 400
+
+    def test_invalid_format_raises_error(self):
+        """Text without an '@' is rejected with a 400 mentioning 'Invalid email'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_email("not-an-email")
+        assert excinfo.value.status_code == 400
+        assert "Invalid email" in excinfo.value.detail
+
+
+class TestValidateUrl:
+    """Behaviour of validate_url for accepted and rejected input."""
+
+    def test_valid_url(self):
+        """Well-formed http and https URLs are returned unchanged."""
+        for url in (
+            "https://example.com",
+            "http://sub.domain.com/path",
+        ):
+            assert validate_url(url) == url
+
+    def test_trims_whitespace(self):
+        """Surrounding spaces are stripped from an otherwise valid URL."""
+        assert validate_url(" https://example.com ") == "https://example.com"
+
+    def test_empty_raises_error(self):
+        """An empty string is rejected with a 400."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_url("")
+        assert excinfo.value.status_code == 400
+
+    def test_invalid_format_raises_error(self):
+        """Text without a scheme is rejected with a 400 mentioning 'Invalid URL'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_url("not-a-url")
+        assert excinfo.value.status_code == 400
+        assert "Invalid URL" in excinfo.value.detail
+
+
+class TestValidateUuid:
+    """Behaviour of validate_uuid for accepted and rejected input."""
+
+    def test_valid_uuid(self):
+        """A canonical hyphenated UUID string is returned unchanged."""
+        canonical = "123e4567-e89b-12d3-a456-426614174000"
+        assert validate_uuid(canonical) == canonical
+
+    def test_trims_whitespace(self):
+        """Surrounding spaces are stripped from an otherwise valid UUID."""
+        padded = " 123e4567-e89b-12d3-a456-426614174000 "
+        assert validate_uuid(padded) == "123e4567-e89b-12d3-a456-426614174000"
+
+    def test_empty_raises_error(self):
+        """An empty string is rejected with a 400."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_uuid("")
+        assert excinfo.value.status_code == 400
+
+    def test_invalid_format_raises_error(self):
+        """Non-UUID text is rejected with a 400 mentioning 'Invalid UUID'."""
+        with pytest.raises(HTTPException) as excinfo:
+            validate_uuid("not-a-uuid")
+        assert excinfo.value.status_code == 400
+        assert "Invalid UUID" in excinfo.value.detail