diff --git a/surfsense_backend/tests/test_agent_configuration.py b/surfsense_backend/tests/test_agent_configuration.py new file mode 100644 index 000000000..5729d68e3 --- /dev/null +++ b/surfsense_backend/tests/test_agent_configuration.py @@ -0,0 +1,200 @@ +""" +Tests for researcher agent configuration. +Tests the Configuration dataclass and SearchMode enum. +""" +import pytest +from dataclasses import fields + +from app.agents.researcher.configuration import Configuration, SearchMode + + +class TestSearchMode: + """Tests for SearchMode enum.""" + + def test_chunks_mode_value(self): + """Test CHUNKS mode value.""" + assert SearchMode.CHUNKS.value == "CHUNKS" + + def test_documents_mode_value(self): + """Test DOCUMENTS mode value.""" + assert SearchMode.DOCUMENTS.value == "DOCUMENTS" + + def test_all_modes_are_strings(self): + """Test all modes have string values.""" + for mode in SearchMode: + assert isinstance(mode.value, str) + + def test_can_compare_modes(self): + """Test enum comparison.""" + chunks_mode = SearchMode.CHUNKS + documents_mode = SearchMode.DOCUMENTS + assert chunks_mode == SearchMode.CHUNKS + assert chunks_mode != documents_mode + + +class TestConfiguration: + """Tests for Configuration dataclass.""" + + def test_create_configuration_with_required_params(self): + """Test creating configuration with required parameters.""" + config = Configuration( + user_query="test query", + connectors_to_search=["TAVILY_API"], + user_id="user-123", + search_space_id=1, + search_mode=SearchMode.CHUNKS, + document_ids_to_add_in_context=[], + ) + + assert config.user_query == "test query" + assert config.connectors_to_search == ["TAVILY_API"] + assert config.user_id == "user-123" + assert config.search_space_id == 1 + assert config.search_mode == SearchMode.CHUNKS + assert config.document_ids_to_add_in_context == [] + + def test_create_configuration_with_optional_params(self): + """Test creating configuration with optional parameters.""" + config = Configuration( + user_query="test query", + connectors_to_search=["TAVILY_API"], + user_id="user-123", + search_space_id=1, + search_mode=SearchMode.DOCUMENTS, + document_ids_to_add_in_context=[1, 2, 3], + language="en", + top_k=20, + ) + + assert config.language == "en" + assert config.top_k == 20 + assert config.document_ids_to_add_in_context == [1, 2, 3] + + def test_default_language_is_none(self): + """Test default language is None.""" + config = Configuration( + user_query="test", + connectors_to_search=[], + user_id="user-123", + search_space_id=1, + search_mode=SearchMode.CHUNKS, + document_ids_to_add_in_context=[], + ) + + assert config.language is None + + def test_default_top_k_is_10(self): + """Test default top_k is 10.""" + config = Configuration( + user_query="test", + connectors_to_search=[], + user_id="user-123", + search_space_id=1, + search_mode=SearchMode.CHUNKS, + document_ids_to_add_in_context=[], + ) + + assert config.top_k == 10 + + def test_from_runnable_config_with_none(self): + """Test from_runnable_config with None returns defaults.""" + # This should not raise an error but will fail due to missing required fields + # We're testing that the method handles None gracefully + with pytest.raises(TypeError): + # Missing required fields should raise TypeError + Configuration.from_runnable_config(None) + + def test_from_runnable_config_with_empty_config(self): + """Test from_runnable_config with empty config.""" + with pytest.raises(TypeError): + # Missing required fields should raise TypeError + 
Configuration.from_runnable_config({}) + + def test_from_runnable_config_with_valid_config(self): + """Test from_runnable_config with valid config.""" + runnable_config = { + "configurable": { + "user_query": "test query", + "connectors_to_search": ["TAVILY_API"], + "user_id": "user-123", + "search_space_id": 1, + "search_mode": SearchMode.CHUNKS, + "document_ids_to_add_in_context": [], + "language": "en", + "top_k": 15, + } + } + + config = Configuration.from_runnable_config(runnable_config) + + assert config.user_query == "test query" + assert config.connectors_to_search == ["TAVILY_API"] + assert config.language == "en" + assert config.top_k == 15 + + def test_from_runnable_config_ignores_unknown_fields(self): + """Test from_runnable_config ignores unknown fields.""" + runnable_config = { + "configurable": { + "user_query": "test query", + "connectors_to_search": ["TAVILY_API"], + "user_id": "user-123", + "search_space_id": 1, + "search_mode": SearchMode.CHUNKS, + "document_ids_to_add_in_context": [], + "unknown_field": "should be ignored", + "another_unknown": 123, + } + } + + config = Configuration.from_runnable_config(runnable_config) + + assert not hasattr(config, "unknown_field") + assert not hasattr(config, "another_unknown") + + def test_configuration_has_expected_fields(self): + """Test Configuration has all expected fields.""" + field_names = {f.name for f in fields(Configuration)} + + expected_fields = { + "user_query", + "connectors_to_search", + "user_id", + "search_space_id", + "search_mode", + "document_ids_to_add_in_context", + "language", + "top_k", + } + + assert field_names == expected_fields + + def test_configuration_multiple_connectors(self): + """Test configuration with multiple connectors.""" + config = Configuration( + user_query="test", + connectors_to_search=["TAVILY_API", "SLACK_CONNECTOR", "NOTION_CONNECTOR"], + user_id="user-123", + search_space_id=1, + search_mode=SearchMode.CHUNKS, + document_ids_to_add_in_context=[], + ) + + assert len(config.connectors_to_search) == 3 + assert "TAVILY_API" in config.connectors_to_search + assert "SLACK_CONNECTOR" in config.connectors_to_search + assert "NOTION_CONNECTOR" in config.connectors_to_search + + def test_configuration_with_document_ids(self): + """Test configuration with document IDs to add to context.""" + config = Configuration( + user_query="test", + connectors_to_search=[], + user_id="user-123", + search_space_id=1, + search_mode=SearchMode.CHUNKS, + document_ids_to_add_in_context=[1, 2, 3, 4, 5], + ) + + assert config.document_ids_to_add_in_context == [1, 2, 3, 4, 5] + assert len(config.document_ids_to_add_in_context) == 5 diff --git a/surfsense_backend/tests/test_celery_tasks.py b/surfsense_backend/tests/test_celery_tasks.py new file mode 100644 index 000000000..c6fc8037c --- /dev/null +++ b/surfsense_backend/tests/test_celery_tasks.py @@ -0,0 +1,350 @@ +"""Tests for Celery tasks module.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.tasks.celery_tasks.connector_tasks import ( + get_celery_session_maker, + _index_slack_messages, + _index_notion_pages, + _index_github_repos, + _index_linear_issues, + _index_jira_issues, + _index_confluence_pages, + _index_clickup_tasks, + _index_google_calendar_events, + _index_airtable_records, + _index_google_gmail_messages, + _index_discord_messages, + _index_luma_events, + _index_elasticsearch_documents, + _index_crawled_urls, +) + + +class TestGetCelerySessionMaker: + """Tests for get_celery_session_maker function.""" + + def 
test_returns_session_maker(self): + """Test that get_celery_session_maker returns a session maker.""" + with patch("app.tasks.celery_tasks.connector_tasks.create_async_engine") as mock_engine: + with patch("app.tasks.celery_tasks.connector_tasks.async_sessionmaker") as mock_session_maker: + mock_engine.return_value = MagicMock() + mock_session_maker.return_value = MagicMock() + + result = get_celery_session_maker() + + assert result is not None + mock_engine.assert_called_once() + mock_session_maker.assert_called_once() + + def test_uses_null_pool(self): + """Test that NullPool is used for Celery tasks.""" + from sqlalchemy.pool import NullPool + + with patch("app.tasks.celery_tasks.connector_tasks.create_async_engine") as mock_engine: + with patch("app.tasks.celery_tasks.connector_tasks.async_sessionmaker"): + get_celery_session_maker() + + # Check that NullPool was passed + call_kwargs = mock_engine.call_args[1] + assert call_kwargs.get("poolclass") == NullPool + + +class TestIndexSlackMessages: + """Tests for Slack message indexing task.""" + + @pytest.mark.asyncio + async def test_index_slack_messages_calls_run_slack_indexing(self): + """Test that _index_slack_messages calls run_slack_indexing.""" + mock_session = AsyncMock() + mock_run_indexing = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + # Create a mock context manager + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_slack_indexing", mock_run_indexing): + await _index_slack_messages(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexNotionPages: + """Tests for Notion page indexing task.""" + + @pytest.mark.asyncio + async def test_index_notion_pages_calls_correct_function(self): + """Test that _index_notion_pages calls run_notion_indexing.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_notion_indexing", new_callable=AsyncMock) as mock_run: + await _index_notion_pages(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexGithubRepos: + """Tests for GitHub repository indexing task.""" + + @pytest.mark.asyncio + async def test_index_github_repos_with_valid_params(self): + """Test GitHub repo indexing with valid parameters.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_github_indexing", new_callable=AsyncMock): + await _index_github_repos(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexLinearIssues: + """Tests for Linear issues indexing task.""" + + @pytest.mark.asyncio + async def test_index_linear_issues_creates_session(self): + """Test that Linear indexing creates a proper session.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as 
mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_linear_indexing", new_callable=AsyncMock): + await _index_linear_issues(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexJiraIssues: + """Tests for Jira issues indexing task.""" + + @pytest.mark.asyncio + async def test_index_jira_issues_passes_correct_params(self): + """Test that Jira indexing passes correct parameters.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_jira_indexing", new_callable=AsyncMock) as mock_run: + await _index_jira_issues(5, 10, "user123", "2024-06-01", "2024-06-30") + mock_run.assert_called_once_with( + mock_session, 5, 10, "user123", "2024-06-01", "2024-06-30" + ) + + +class TestIndexConfluencePages: + """Tests for Confluence pages indexing task.""" + + @pytest.mark.asyncio + async def test_index_confluence_pages_with_valid_params(self): + """Test Confluence indexing with valid parameters.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_confluence_indexing", new_callable=AsyncMock): + await _index_confluence_pages(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexClickupTasks: + """Tests for ClickUp tasks indexing.""" + + @pytest.mark.asyncio + async def test_index_clickup_tasks_creates_session(self): + """Test ClickUp indexing creates session.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_clickup_indexing", new_callable=AsyncMock): + await _index_clickup_tasks(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexGoogleCalendarEvents: + """Tests for Google Calendar events indexing.""" + + @pytest.mark.asyncio + async def test_index_google_calendar_events_with_valid_params(self): + """Test Google Calendar indexing with valid parameters.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_google_calendar_indexing", new_callable=AsyncMock): + await _index_google_calendar_events(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexAirtableRecords: + """Tests for Airtable records indexing.""" + + @pytest.mark.asyncio + async def test_index_airtable_records_creates_session(self): + """Test 
Airtable indexing creates session.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_airtable_indexing", new_callable=AsyncMock): + await _index_airtable_records(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexGoogleGmailMessages: + """Tests for Google Gmail messages indexing.""" + + @pytest.mark.asyncio + async def test_index_gmail_messages_calculates_days_back(self): + """Test Gmail indexing calculates days_back from start_date.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_google_gmail_indexing", new_callable=AsyncMock) as mock_run: + await _index_google_gmail_messages(1, 1, "user1", "2024-01-01", "2024-12-31") + # Should have been called with calculated days_back + mock_run.assert_called_once() + + @pytest.mark.asyncio + async def test_index_gmail_messages_default_days_back(self): + """Test Gmail indexing uses default days_back when no start_date.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_google_gmail_indexing", new_callable=AsyncMock) as mock_run: + await _index_google_gmail_messages(1, 1, "user1", None, None) + # Should have been called with max_messages=100 and default days_back=30 + # Args: session, connector_id, search_space_id, user_id, max_messages, days_back + mock_run.assert_called_once() + call_args = mock_run.call_args[0] + assert call_args[4] == 100 # max_messages (index 4) + assert call_args[5] == 30 # days_back (index 5) + + @pytest.mark.asyncio + async def test_index_gmail_messages_invalid_date_uses_default(self): + """Test Gmail indexing uses default when date parsing fails.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_google_gmail_indexing", new_callable=AsyncMock) as mock_run: + await _index_google_gmail_messages(1, 1, "user1", "invalid-date", None) + mock_run.assert_called_once() + # Args: session, connector_id, search_space_id, user_id, max_messages, days_back + call_args = mock_run.call_args[0] + assert call_args[4] == 100 # max_messages (index 4) + assert call_args[5] == 30 # days_back default (index 5) + + +class TestIndexDiscordMessages: + """Tests for Discord messages indexing.""" + + @pytest.mark.asyncio + async def test_index_discord_messages_with_valid_params(self): + """Test Discord indexing with valid parameters.""" + mock_session = AsyncMock() + + with 
patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_discord_indexing", new_callable=AsyncMock): + await _index_discord_messages(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexLumaEvents: + """Tests for Luma events indexing.""" + + @pytest.mark.asyncio + async def test_index_luma_events_creates_session(self): + """Test Luma indexing creates session.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_luma_indexing", new_callable=AsyncMock): + await _index_luma_events(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexElasticsearchDocuments: + """Tests for Elasticsearch documents indexing.""" + + @pytest.mark.asyncio + async def test_index_elasticsearch_documents_with_valid_params(self): + """Test Elasticsearch indexing with valid parameters.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_elasticsearch_indexing", new_callable=AsyncMock): + await _index_elasticsearch_documents(1, 1, "user1", "2024-01-01", "2024-12-31") + + +class TestIndexCrawledUrls: + """Tests for web page URL indexing.""" + + @pytest.mark.asyncio + async def test_index_crawled_urls_creates_session(self): + """Test web page indexing creates session.""" + mock_session = AsyncMock() + + with patch("app.tasks.celery_tasks.connector_tasks.get_celery_session_maker") as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch("app.routes.search_source_connectors_routes.run_web_page_indexing", new_callable=AsyncMock): + await _index_crawled_urls(1, 1, "user1", "2024-01-01", "2024-12-31") diff --git a/surfsense_backend/tests/test_celery_tasks_comprehensive.py b/surfsense_backend/tests/test_celery_tasks_comprehensive.py new file mode 100644 index 000000000..02160b763 --- /dev/null +++ b/surfsense_backend/tests/test_celery_tasks_comprehensive.py @@ -0,0 +1,1046 @@ +"""Comprehensive tests for Celery tasks module.""" + +from datetime import datetime, timedelta, UTC +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# ============================================================================ +# SCHEDULE CHECKER TASK TESTS +# ============================================================================ + + +class TestScheduleCheckerTaskSessionMaker: + """Tests for schedule checker task session maker.""" + + def test_get_celery_session_maker_returns_maker(self): + """Test that get_celery_session_maker returns a session maker.""" + from app.tasks.celery_tasks.schedule_checker_task import get_celery_session_maker + + with 
patch( + "app.tasks.celery_tasks.schedule_checker_task.create_async_engine" + ) as mock_engine: + with patch( + "app.tasks.celery_tasks.schedule_checker_task.async_sessionmaker" + ) as mock_maker: + mock_engine.return_value = MagicMock() + mock_maker.return_value = MagicMock() + + result = get_celery_session_maker() + + assert result is not None + mock_engine.assert_called_once() + + def test_get_celery_session_maker_uses_null_pool(self): + """Test that NullPool is used.""" + from sqlalchemy.pool import NullPool + from app.tasks.celery_tasks.schedule_checker_task import get_celery_session_maker + + with patch( + "app.tasks.celery_tasks.schedule_checker_task.create_async_engine" + ) as mock_engine: + with patch( + "app.tasks.celery_tasks.schedule_checker_task.async_sessionmaker" + ): + get_celery_session_maker() + + call_kwargs = mock_engine.call_args[1] + assert call_kwargs.get("poolclass") == NullPool + + +class TestCheckAndTriggerSchedules: + """Tests for _check_and_trigger_schedules function.""" + + @pytest.mark.asyncio + async def test_no_due_connectors(self): + """Test when no connectors are due for indexing.""" + from app.tasks.celery_tasks.schedule_checker_task import ( + _check_and_trigger_schedules, + ) + + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.schedule_checker_task.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + await _check_and_trigger_schedules() + + mock_session.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_triggers_slack_connector_task(self): + """Test triggering Slack connector indexing task.""" + from app.tasks.celery_tasks.schedule_checker_task import ( + _check_and_trigger_schedules, + ) + from app.db import SearchSourceConnectorType + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.id = 1 + mock_connector.search_space_id = 1 + mock_connector.user_id = "user123" + mock_connector.connector_type = SearchSourceConnectorType.SLACK_CONNECTOR + mock_connector.indexing_frequency_minutes = 60 + mock_connector.next_scheduled_at = datetime.now(UTC) - timedelta(minutes=5) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_connector] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.schedule_checker_task.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.connector_tasks.index_slack_messages_task" + ) as mock_slack_task: + mock_slack_task.delay = MagicMock() + + await _check_and_trigger_schedules() + + mock_slack_task.delay.assert_called_once_with( + 1, 1, "user123", None, None + ) + assert mock_connector.next_scheduled_at is not None + + @pytest.mark.asyncio + async def test_triggers_notion_connector_task(self): + """Test triggering Notion connector indexing task.""" + from app.tasks.celery_tasks.schedule_checker_task import ( + _check_and_trigger_schedules, + ) + from app.db import SearchSourceConnectorType + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.id = 2 + 
mock_connector.search_space_id = 1 + mock_connector.user_id = "user456" + mock_connector.connector_type = SearchSourceConnectorType.NOTION_CONNECTOR + mock_connector.indexing_frequency_minutes = 120 + mock_connector.next_scheduled_at = datetime.now(UTC) - timedelta(minutes=10) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_connector] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.schedule_checker_task.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.connector_tasks.index_notion_pages_task" + ) as mock_notion_task: + mock_notion_task.delay = MagicMock() + + await _check_and_trigger_schedules() + + mock_notion_task.delay.assert_called_once() + + @pytest.mark.asyncio + async def test_triggers_github_connector_task(self): + """Test triggering GitHub connector indexing task.""" + from app.tasks.celery_tasks.schedule_checker_task import ( + _check_and_trigger_schedules, + ) + from app.db import SearchSourceConnectorType + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.id = 3 + mock_connector.search_space_id = 2 + mock_connector.user_id = "user789" + mock_connector.connector_type = SearchSourceConnectorType.GITHUB_CONNECTOR + mock_connector.indexing_frequency_minutes = 30 + mock_connector.next_scheduled_at = datetime.now(UTC) - timedelta(minutes=1) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_connector] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.schedule_checker_task.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.connector_tasks.index_github_repos_task" + ) as mock_github_task: + mock_github_task.delay = MagicMock() + + await _check_and_trigger_schedules() + + mock_github_task.delay.assert_called_once() + + @pytest.mark.asyncio + async def test_triggers_multiple_connector_types(self): + """Test triggering multiple different connector types.""" + from app.tasks.celery_tasks.schedule_checker_task import ( + _check_and_trigger_schedules, + ) + from app.db import SearchSourceConnectorType + + mock_session = AsyncMock() + + # Create multiple connectors of different types + mock_connectors = [] + connector_types = [ + SearchSourceConnectorType.SLACK_CONNECTOR, + SearchSourceConnectorType.JIRA_CONNECTOR, + SearchSourceConnectorType.CONFLUENCE_CONNECTOR, + ] + + for i, ct in enumerate(connector_types): + mock_connector = MagicMock() + mock_connector.id = i + 1 + mock_connector.search_space_id = 1 + mock_connector.user_id = f"user{i}" + mock_connector.connector_type = ct + mock_connector.indexing_frequency_minutes = 60 + mock_connector.next_scheduled_at = datetime.now(UTC) - timedelta(minutes=5) + mock_connectors.append(mock_connector) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = mock_connectors + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.schedule_checker_task.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = 
mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.connector_tasks.index_slack_messages_task" + ) as mock_slack: + with patch( + "app.tasks.celery_tasks.connector_tasks.index_jira_issues_task" + ) as mock_jira: + with patch( + "app.tasks.celery_tasks.connector_tasks.index_confluence_pages_task" + ) as mock_confluence: + mock_slack.delay = MagicMock() + mock_jira.delay = MagicMock() + mock_confluence.delay = MagicMock() + + await _check_and_trigger_schedules() + + mock_slack.delay.assert_called_once() + mock_jira.delay.assert_called_once() + mock_confluence.delay.assert_called_once() + + @pytest.mark.asyncio + async def test_handles_unknown_connector_type(self): + """Test handling of unknown connector type gracefully.""" + from app.tasks.celery_tasks.schedule_checker_task import ( + _check_and_trigger_schedules, + ) + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.id = 1 + mock_connector.search_space_id = 1 + mock_connector.user_id = "user123" + mock_connector.connector_type = MagicMock() # Unknown type + mock_connector.connector_type.value = "UNKNOWN_CONNECTOR" + mock_connector.indexing_frequency_minutes = 60 + mock_connector.next_scheduled_at = datetime.now(UTC) - timedelta(minutes=5) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_connector] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.schedule_checker_task.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + # Should not raise an exception + await _check_and_trigger_schedules() + + @pytest.mark.asyncio + async def test_updates_next_scheduled_at(self): + """Test that next_scheduled_at is updated after triggering.""" + from app.tasks.celery_tasks.schedule_checker_task import ( + _check_and_trigger_schedules, + ) + from app.db import SearchSourceConnectorType + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.id = 1 + mock_connector.search_space_id = 1 + mock_connector.user_id = "user123" + mock_connector.connector_type = SearchSourceConnectorType.SLACK_CONNECTOR + mock_connector.indexing_frequency_minutes = 60 + original_time = datetime.now(UTC) - timedelta(minutes=5) + mock_connector.next_scheduled_at = original_time + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_connector] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.schedule_checker_task.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.connector_tasks.index_slack_messages_task" + ) as mock_slack: + mock_slack.delay = MagicMock() + + await _check_and_trigger_schedules() + + # Check that next_scheduled_at was updated + assert mock_connector.next_scheduled_at != original_time + mock_session.commit.assert_called() + + @pytest.mark.asyncio + async def test_handles_database_error(self): + """Test handling of database errors.""" + from app.tasks.celery_tasks.schedule_checker_task import ( + _check_and_trigger_schedules, + ) + from sqlalchemy.exc import SQLAlchemyError 
+ + mock_session = AsyncMock() + mock_session.execute.side_effect = SQLAlchemyError("DB error") + + with patch( + "app.tasks.celery_tasks.schedule_checker_task.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + # Should not raise, just log error + await _check_and_trigger_schedules() + + mock_session.rollback.assert_called_once() + + +# ============================================================================ +# BLOCKNOTE MIGRATION TASK TESTS +# ============================================================================ + + +class TestBlocknoteMigrationTaskSessionMaker: + """Tests for blocknote migration task session maker.""" + + def test_get_celery_session_maker_returns_maker(self): + """Test that get_celery_session_maker returns a session maker.""" + from app.tasks.celery_tasks.blocknote_migration_tasks import ( + get_celery_session_maker, + ) + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.create_async_engine" + ) as mock_engine: + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.async_sessionmaker" + ) as mock_maker: + mock_engine.return_value = MagicMock() + mock_maker.return_value = MagicMock() + + result = get_celery_session_maker() + + assert result is not None + + def test_get_celery_session_maker_uses_null_pool(self): + """Test that NullPool is used.""" + from sqlalchemy.pool import NullPool + from app.tasks.celery_tasks.blocknote_migration_tasks import ( + get_celery_session_maker, + ) + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.create_async_engine" + ) as mock_engine: + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.async_sessionmaker" + ): + get_celery_session_maker() + + call_kwargs = mock_engine.call_args[1] + assert call_kwargs.get("poolclass") == NullPool + + +class TestPopulateBlocknoteForDocuments: + """Tests for _populate_blocknote_for_documents function.""" + + @pytest.mark.asyncio + async def test_no_documents_to_process(self): + """Test when no documents need blocknote population.""" + from app.tasks.celery_tasks.blocknote_migration_tasks import ( + _populate_blocknote_for_documents, + ) + + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + await _populate_blocknote_for_documents() + + mock_session.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_processes_documents_with_chunks(self): + """Test processing documents that have chunks.""" + from app.tasks.celery_tasks.blocknote_migration_tasks import ( + _populate_blocknote_for_documents, + ) + + mock_session = AsyncMock() + + # Create mock document with chunks + mock_chunk1 = MagicMock() + mock_chunk1.id = 1 + mock_chunk1.content = "# Header\n\nFirst chunk content" + + mock_chunk2 = MagicMock() + mock_chunk2.id = 2 + mock_chunk2.content = "Second chunk content" + + mock_document = MagicMock() + mock_document.id = 1 + mock_document.title = "Test Document" + mock_document.chunks = [mock_chunk1, mock_chunk2] + 
mock_document.blocknote_document = None + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_document] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.convert_markdown_to_blocknote", + new_callable=AsyncMock, + ) as mock_convert: + mock_convert.return_value = {"type": "doc", "content": []} + + await _populate_blocknote_for_documents() + + mock_convert.assert_called_once() + mock_session.commit.assert_called() + + @pytest.mark.asyncio + async def test_skips_documents_without_chunks(self): + """Test skipping documents that have no chunks.""" + from app.tasks.celery_tasks.blocknote_migration_tasks import ( + _populate_blocknote_for_documents, + ) + + mock_session = AsyncMock() + + mock_document = MagicMock() + mock_document.id = 1 + mock_document.title = "Empty Document" + mock_document.chunks = [] + mock_document.blocknote_document = None + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_document] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.convert_markdown_to_blocknote", + new_callable=AsyncMock, + ) as mock_convert: + await _populate_blocknote_for_documents() + + # Should not call convert for empty document + mock_convert.assert_not_called() + + @pytest.mark.asyncio + async def test_processes_specific_document_ids(self): + """Test processing only specific document IDs.""" + from app.tasks.celery_tasks.blocknote_migration_tasks import ( + _populate_blocknote_for_documents, + ) + + mock_session = AsyncMock() + + mock_chunk = MagicMock() + mock_chunk.id = 1 + mock_chunk.content = "Test content" + + mock_document = MagicMock() + mock_document.id = 5 + mock_document.title = "Specific Document" + mock_document.chunks = [mock_chunk] + mock_document.blocknote_document = None + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_document] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.convert_markdown_to_blocknote", + new_callable=AsyncMock, + ) as mock_convert: + mock_convert.return_value = {"type": "doc", "content": []} + + await _populate_blocknote_for_documents(document_ids=[5, 10, 15]) + + mock_session.execute.assert_called_once() + + @pytest.mark.asyncio + async def test_handles_conversion_failure(self): + """Test handling conversion failures gracefully.""" + from app.tasks.celery_tasks.blocknote_migration_tasks import ( + _populate_blocknote_for_documents, + ) + + mock_session = AsyncMock() + + 
mock_chunk = MagicMock() + mock_chunk.id = 1 + mock_chunk.content = "Test content" + + mock_document = MagicMock() + mock_document.id = 1 + mock_document.title = "Test Document" + mock_document.chunks = [mock_chunk] + mock_document.blocknote_document = None + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_document] + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.convert_markdown_to_blocknote", + new_callable=AsyncMock, + ) as mock_convert: + mock_convert.return_value = None # Conversion failed + + await _populate_blocknote_for_documents() + + # Should still commit (with failures tracked) + mock_session.commit.assert_called() + + @pytest.mark.asyncio + async def test_batch_processing(self): + """Test batch processing of multiple documents.""" + from app.tasks.celery_tasks.blocknote_migration_tasks import ( + _populate_blocknote_for_documents, + ) + + mock_session = AsyncMock() + + # Create multiple documents + documents = [] + for i in range(5): + mock_chunk = MagicMock() + mock_chunk.id = i + mock_chunk.content = f"Content {i}" + + mock_doc = MagicMock() + mock_doc.id = i + mock_doc.title = f"Document {i}" + mock_doc.chunks = [mock_chunk] + mock_doc.blocknote_document = None + documents.append(mock_doc) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = documents + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.blocknote_migration_tasks.convert_markdown_to_blocknote", + new_callable=AsyncMock, + ) as mock_convert: + mock_convert.return_value = {"type": "doc", "content": []} + + await _populate_blocknote_for_documents(batch_size=2) + + # Should have called convert for each document + assert mock_convert.call_count == 5 + + +# ============================================================================ +# DOCUMENT REINDEX TASK TESTS +# ============================================================================ + + +class TestDocumentReindexTaskSessionMaker: + """Tests for document reindex task session maker.""" + + def test_get_celery_session_maker_returns_maker(self): + """Test that get_celery_session_maker returns a session maker.""" + from app.tasks.celery_tasks.document_reindex_tasks import get_celery_session_maker + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.create_async_engine" + ) as mock_engine: + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.async_sessionmaker" + ) as mock_maker: + mock_engine.return_value = MagicMock() + mock_maker.return_value = MagicMock() + + result = get_celery_session_maker() + + assert result is not None + + def test_get_celery_session_maker_uses_null_pool(self): + """Test that NullPool is used.""" + from sqlalchemy.pool import NullPool + from app.tasks.celery_tasks.document_reindex_tasks import get_celery_session_maker + + with patch( + 
"app.tasks.celery_tasks.document_reindex_tasks.create_async_engine" + ) as mock_engine: + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.async_sessionmaker" + ): + get_celery_session_maker() + + call_kwargs = mock_engine.call_args[1] + assert call_kwargs.get("poolclass") == NullPool + + +class TestReindexDocument: + """Tests for _reindex_document function.""" + + @pytest.mark.asyncio + async def test_document_not_found(self): + """Test handling when document is not found.""" + from app.tasks.celery_tasks.document_reindex_tasks import _reindex_document + + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + await _reindex_document(999, "user1") + + # Should not commit anything + mock_session.commit.assert_not_called() + + @pytest.mark.asyncio + async def test_document_without_blocknote_content(self): + """Test handling document without blocknote content.""" + from app.tasks.celery_tasks.document_reindex_tasks import _reindex_document + + mock_session = AsyncMock() + mock_document = MagicMock() + mock_document.id = 1 + mock_document.blocknote_document = None + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_document + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + await _reindex_document(1, "user1") + + mock_session.commit.assert_not_called() + + @pytest.mark.asyncio + async def test_successful_reindex(self): + """Test successful document reindexing.""" + from app.tasks.celery_tasks.document_reindex_tasks import _reindex_document + from app.db import DocumentType + + mock_session = AsyncMock() + # session.add is synchronous, so use MagicMock + mock_session.add = MagicMock() + mock_document = MagicMock() + mock_document.id = 1 + mock_document.title = "Test Document" + mock_document.blocknote_document = {"type": "doc", "content": []} + mock_document.document_type = DocumentType.FILE + mock_document.search_space_id = 1 + mock_document.chunks = [] + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_document + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.convert_blocknote_to_markdown", + new_callable=AsyncMock, + ) as mock_convert: + mock_convert.return_value = "# Test Document\n\nContent here" + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.create_document_chunks", + new_callable=AsyncMock, + ) as mock_chunks: + mock_chunk = MagicMock() + mock_chunk.document_id = None + mock_chunks.return_value = [mock_chunk] + + with patch( + 
"app.tasks.celery_tasks.document_reindex_tasks.get_user_long_context_llm", + new_callable=AsyncMock, + ) as mock_llm: + mock_llm_instance = MagicMock() + mock_llm.return_value = mock_llm_instance + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.generate_document_summary", + new_callable=AsyncMock, + ) as mock_summary: + mock_summary.return_value = ( + "Summary content", + [0.1, 0.2, 0.3], + ) + + await _reindex_document(1, "user1") + + mock_convert.assert_called_once() + mock_chunks.assert_called_once() + mock_summary.assert_called_once() + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_reindex_deletes_old_chunks(self): + """Test that old chunks are deleted during reindex.""" + from app.tasks.celery_tasks.document_reindex_tasks import _reindex_document + from app.db import DocumentType + + mock_session = AsyncMock() + mock_document = MagicMock() + mock_document.id = 1 + mock_document.title = "Test" + mock_document.blocknote_document = {"type": "doc"} + mock_document.document_type = DocumentType.FILE + mock_document.search_space_id = 1 + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_document + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.convert_blocknote_to_markdown", + new_callable=AsyncMock, + ) as mock_convert: + mock_convert.return_value = "Content" + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.create_document_chunks", + new_callable=AsyncMock, + ) as mock_chunks: + mock_chunks.return_value = [] + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.get_user_long_context_llm", + new_callable=AsyncMock, + ): + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.generate_document_summary", + new_callable=AsyncMock, + ) as mock_summary: + mock_summary.return_value = ("Summary", [0.1]) + + await _reindex_document(1, "user1") + + # Verify delete was called (execute is called for select and delete) + assert mock_session.execute.call_count >= 2 + mock_session.flush.assert_called() + + @pytest.mark.asyncio + async def test_handles_conversion_failure(self): + """Test handling markdown conversion failure.""" + from app.tasks.celery_tasks.document_reindex_tasks import _reindex_document + + mock_session = AsyncMock() + mock_document = MagicMock() + mock_document.id = 1 + mock_document.title = "Test" + mock_document.blocknote_document = {"type": "doc"} + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_document + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.convert_blocknote_to_markdown", + new_callable=AsyncMock, + ) as mock_convert: + mock_convert.return_value = None # Conversion failed + + await _reindex_document(1, "user1") + + mock_session.commit.assert_not_called() + + @pytest.mark.asyncio + async def 
test_handles_database_error(self): + """Test handling database errors during reindex.""" + from app.tasks.celery_tasks.document_reindex_tasks import _reindex_document + from sqlalchemy.exc import SQLAlchemyError + from app.db import DocumentType + + mock_session = AsyncMock() + mock_document = MagicMock() + mock_document.id = 1 + mock_document.title = "Test" + mock_document.blocknote_document = {"type": "doc"} + mock_document.document_type = DocumentType.FILE + mock_document.search_space_id = 1 + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_document + mock_session.execute.return_value = mock_result + mock_session.commit.side_effect = SQLAlchemyError("DB error") + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.convert_blocknote_to_markdown", + new_callable=AsyncMock, + ) as mock_convert: + mock_convert.return_value = "Content" + + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.create_document_chunks", + new_callable=AsyncMock, + ): + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.get_user_long_context_llm", + new_callable=AsyncMock, + ): + with patch( + "app.tasks.celery_tasks.document_reindex_tasks.generate_document_summary", + new_callable=AsyncMock, + ) as mock_summary: + mock_summary.return_value = ("Summary", [0.1]) + + with pytest.raises(SQLAlchemyError): + await _reindex_document(1, "user1") + + mock_session.rollback.assert_called_once() + + +# ============================================================================ +# CONNECTOR TASKS ADDITIONAL TESTS +# ============================================================================ + + +class TestConnectorTasksGmailDaysBackCalculation: + """Additional tests for Gmail days_back calculation.""" + + @pytest.mark.asyncio + async def test_gmail_calculates_correct_days_back(self): + """Test Gmail indexing calculates correct days_back from start_date.""" + from app.tasks.celery_tasks.connector_tasks import _index_google_gmail_messages + from datetime import datetime, timedelta + + mock_session = AsyncMock() + + # Set start_date to 15 days ago + start_date = (datetime.now() - timedelta(days=15)).strftime("%Y-%m-%d") + + with patch( + "app.tasks.celery_tasks.connector_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.routes.search_source_connectors_routes.run_google_gmail_indexing", + new_callable=AsyncMock, + ) as mock_run: + await _index_google_gmail_messages(1, 1, "user1", start_date, None) + + mock_run.assert_called_once() + call_args = mock_run.call_args[0] + # days_back should be approximately 15 + assert 14 <= call_args[5] <= 16 + + @pytest.mark.asyncio + async def test_gmail_minimum_days_back(self): + """Test Gmail uses minimum of 1 day when start_date is today.""" + from app.tasks.celery_tasks.connector_tasks import _index_google_gmail_messages + from datetime import datetime + + mock_session = AsyncMock() + + # Set start_date to today + start_date = datetime.now().strftime("%Y-%m-%d") + + with patch( + 
"app.tasks.celery_tasks.connector_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.routes.search_source_connectors_routes.run_google_gmail_indexing", + new_callable=AsyncMock, + ) as mock_run: + await _index_google_gmail_messages(1, 1, "user1", start_date, None) + + mock_run.assert_called_once() + call_args = mock_run.call_args[0] + # days_back should be at least 1 + assert call_args[5] >= 1 + + +class TestConnectorTasksErrorHandling: + """Tests for error handling in connector tasks.""" + + @pytest.mark.asyncio + async def test_slack_task_handles_session_error(self): + """Test Slack task handles session creation errors.""" + from app.tasks.celery_tasks.connector_tasks import _index_slack_messages + + with patch( + "app.tasks.celery_tasks.connector_tasks.get_celery_session_maker" + ) as mock_maker: + mock_maker.side_effect = Exception("Session creation failed") + + with pytest.raises(Exception, match="Session creation failed"): + await _index_slack_messages(1, 1, "user1", "2024-01-01", "2024-12-31") + + @pytest.mark.asyncio + async def test_github_task_handles_indexing_error(self): + """Test GitHub task handles indexing errors.""" + from app.tasks.celery_tasks.connector_tasks import _index_github_repos + + mock_session = AsyncMock() + + with patch( + "app.tasks.celery_tasks.connector_tasks.get_celery_session_maker" + ) as mock_maker: + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_context.__aexit__.return_value = None + mock_maker.return_value.return_value = mock_context + + with patch( + "app.routes.search_source_connectors_routes.run_github_indexing", + new_callable=AsyncMock, + ) as mock_run: + mock_run.side_effect = Exception("GitHub API error") + + with pytest.raises(Exception, match="GitHub API error"): + await _index_github_repos(1, 1, "user1", "2024-01-01", "2024-12-31") diff --git a/surfsense_backend/tests/test_connector_config.py b/surfsense_backend/tests/test_connector_config.py new file mode 100644 index 000000000..a786f77db --- /dev/null +++ b/surfsense_backend/tests/test_connector_config.py @@ -0,0 +1,256 @@ +""" +Tests for the connector configuration validation in validators module. 
+""" + +import pytest + +from app.utils.validators import validate_connector_config + + +class TestValidateConnectorConfig: + """Tests for validate_connector_config function.""" + + def test_invalid_config_type_raises_error(self): + """Test that non-dict config raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config("TAVILY_API", "not a dict") + assert "must be a dictionary" in str(exc_info.value) + + def test_boolean_config_raises_error(self): + """Test that boolean config raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config("TAVILY_API", True) + assert "must be a dictionary" in str(exc_info.value) + + def test_tavily_api_valid_config(self): + """Test valid Tavily API configuration.""" + config = {"TAVILY_API_KEY": "test-api-key-123"} + result = validate_connector_config("TAVILY_API", config) + assert result == config + + def test_tavily_api_missing_key_raises_error(self): + """Test that missing TAVILY_API_KEY raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config("TAVILY_API", {}) + assert "TAVILY_API_KEY" in str(exc_info.value) + + def test_tavily_api_empty_key_raises_error(self): + """Test that empty TAVILY_API_KEY raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config("TAVILY_API", {"TAVILY_API_KEY": ""}) + assert "cannot be empty" in str(exc_info.value) + + def test_tavily_api_unexpected_key_raises_error(self): + """Test that unexpected key in config raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config( + "TAVILY_API", + {"TAVILY_API_KEY": "test-key", "UNEXPECTED_KEY": "value"}, + ) + assert "may only contain" in str(exc_info.value) + + def test_linkup_api_valid_config(self): + """Test valid LinkUp API configuration.""" + config = {"LINKUP_API_KEY": "linkup-key-123"} + result = validate_connector_config("LINKUP_API", config) + assert result == config + + def test_searxng_api_valid_config(self): + """Test valid SearxNG API configuration.""" + config = {"SEARXNG_HOST": "https://searxng.example.com"} + result = validate_connector_config("SEARXNG_API", config) + assert result == config + + def test_searxng_api_with_optional_params(self): + """Test SearxNG API with optional parameters.""" + config = { + "SEARXNG_HOST": "https://searxng.example.com", + "SEARXNG_API_KEY": "optional-key", + "SEARXNG_ENGINES": "google,bing", + "SEARXNG_LANGUAGE": "en", + } + result = validate_connector_config("SEARXNG_API", config) + assert result == config + + def test_searxng_api_invalid_host_raises_error(self): + """Test that invalid SEARXNG_HOST raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config("SEARXNG_API", {"SEARXNG_HOST": "not-a-url"}) + assert "Invalid base URL" in str(exc_info.value) + + def test_slack_connector_valid_config(self): + """Test valid Slack connector configuration.""" + config = {"SLACK_BOT_TOKEN": "xoxb-token-123"} + result = validate_connector_config("SLACK_CONNECTOR", config) + assert result == config + + def test_notion_connector_valid_config(self): + """Test valid Notion connector configuration.""" + config = {"NOTION_INTEGRATION_TOKEN": "secret_token_123"} + result = validate_connector_config("NOTION_CONNECTOR", config) + assert result == config + + def test_github_connector_valid_config(self): + """Test valid GitHub connector configuration.""" + config = { + "GITHUB_PAT": "ghp_token_123", + "repo_full_names": ["owner/repo1", 
"owner/repo2"], + } + result = validate_connector_config("GITHUB_CONNECTOR", config) + assert result == config + + def test_github_connector_empty_repos_raises_error(self): + """Test that empty repo_full_names raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config( + "GITHUB_CONNECTOR", + {"GITHUB_PAT": "ghp_token_123", "repo_full_names": []}, + ) + assert "non-empty list" in str(exc_info.value) + + def test_jira_connector_valid_config(self): + """Test valid Jira connector configuration.""" + config = { + "JIRA_EMAIL": "user@example.com", + "JIRA_API_TOKEN": "api-token-123", + "JIRA_BASE_URL": "https://company.atlassian.net", + } + result = validate_connector_config("JIRA_CONNECTOR", config) + assert result == config + + def test_jira_connector_invalid_email_raises_error(self): + """Test that invalid JIRA_EMAIL raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config( + "JIRA_CONNECTOR", + { + "JIRA_EMAIL": "not-an-email", + "JIRA_API_TOKEN": "token", + "JIRA_BASE_URL": "https://company.atlassian.net", + }, + ) + assert "Invalid email" in str(exc_info.value) + + def test_jira_connector_invalid_url_raises_error(self): + """Test that invalid JIRA_BASE_URL raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config( + "JIRA_CONNECTOR", + { + "JIRA_EMAIL": "user@example.com", + "JIRA_API_TOKEN": "token", + "JIRA_BASE_URL": "not-a-url", + }, + ) + assert "Invalid base URL" in str(exc_info.value) + + def test_confluence_connector_valid_config(self): + """Test valid Confluence connector configuration.""" + config = { + "CONFLUENCE_BASE_URL": "https://company.atlassian.net/wiki", + "CONFLUENCE_EMAIL": "user@example.com", + "CONFLUENCE_API_TOKEN": "api-token-123", + } + result = validate_connector_config("CONFLUENCE_CONNECTOR", config) + assert result == config + + def test_linear_connector_valid_config(self): + """Test valid Linear connector configuration.""" + config = {"LINEAR_API_KEY": "lin_api_key_123"} + result = validate_connector_config("LINEAR_CONNECTOR", config) + assert result == config + + def test_discord_connector_valid_config(self): + """Test valid Discord connector configuration.""" + config = {"DISCORD_BOT_TOKEN": "discord-token-123"} + result = validate_connector_config("DISCORD_CONNECTOR", config) + assert result == config + + def test_clickup_connector_valid_config(self): + """Test valid ClickUp connector configuration.""" + config = {"CLICKUP_API_TOKEN": "pk_token_123"} + result = validate_connector_config("CLICKUP_CONNECTOR", config) + assert result == config + + def test_luma_connector_valid_config(self): + """Test valid Luma connector configuration.""" + config = {"LUMA_API_KEY": "luma-key-123"} + result = validate_connector_config("LUMA_CONNECTOR", config) + assert result == config + + def test_webcrawler_connector_valid_without_api_key(self): + """Test valid WebCrawler connector without API key (optional).""" + config = {} + result = validate_connector_config("WEBCRAWLER_CONNECTOR", config) + assert result == config + + def test_webcrawler_connector_valid_with_api_key(self): + """Test valid WebCrawler connector with API key.""" + config = {"FIRECRAWL_API_KEY": "fc-api-key-123"} + result = validate_connector_config("WEBCRAWLER_CONNECTOR", config) + assert result == config + + def test_webcrawler_connector_invalid_api_key_format(self): + """Test that invalid Firecrawl API key format raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + 
validate_connector_config( + "WEBCRAWLER_CONNECTOR", + {"FIRECRAWL_API_KEY": "invalid-format-key"}, + ) + assert "should start with 'fc-'" in str(exc_info.value) + + def test_webcrawler_connector_valid_with_urls(self): + """Test valid WebCrawler connector with initial URLs.""" + config = {"INITIAL_URLS": "https://example.com\nhttps://another.com"} + result = validate_connector_config("WEBCRAWLER_CONNECTOR", config) + assert result == config + + def test_webcrawler_connector_invalid_urls(self): + """Test that invalid URL in INITIAL_URLS raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + validate_connector_config( + "WEBCRAWLER_CONNECTOR", + {"INITIAL_URLS": "https://valid.com\nnot-a-valid-url"}, + ) + assert "Invalid URL format" in str(exc_info.value) + + def test_baidu_search_api_valid_config(self): + """Test valid Baidu Search API configuration.""" + config = {"BAIDU_API_KEY": "baidu-api-key-123"} + result = validate_connector_config("BAIDU_SEARCH_API", config) + assert result == config + + def test_baidu_search_api_with_optional_params(self): + """Test Baidu Search API with optional parameters.""" + config = { + "BAIDU_API_KEY": "baidu-api-key-123", + "BAIDU_MODEL": "ernie-4.0", + "BAIDU_SEARCH_SOURCE": "baidu_search_v2", + "BAIDU_ENABLE_DEEP_SEARCH": True, + } + result = validate_connector_config("BAIDU_SEARCH_API", config) + assert result == config + + def test_serper_api_valid_config(self): + """Test valid Serper API configuration.""" + config = {"SERPER_API_KEY": "serper-api-key-123"} + result = validate_connector_config("SERPER_API", config) + assert result == config + + def test_unknown_connector_type_passes_through(self): + """Test that unknown connector type passes config through unchanged.""" + config = {"ANY_KEY": "any_value"} + result = validate_connector_config("UNKNOWN_CONNECTOR", config) + assert result == config + + def test_connector_type_enum_handling(self): + """Test that connector type enum is handled correctly.""" + from unittest.mock import MagicMock + + mock_enum = MagicMock() + mock_enum.value = "TAVILY_API" + + config = {"TAVILY_API_KEY": "test-key"} + # The function should handle enum-like objects + result = validate_connector_config(mock_enum, config) + assert result == config diff --git a/surfsense_backend/tests/test_connector_indexers_comprehensive.py b/surfsense_backend/tests/test_connector_indexers_comprehensive.py new file mode 100644 index 000000000..c7f3f5ddf --- /dev/null +++ b/surfsense_backend/tests/test_connector_indexers_comprehensive.py @@ -0,0 +1,1178 @@ +"""Comprehensive tests for connector indexers module.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +# ============================================================================ +# SLACK INDEXER TESTS +# ============================================================================ + + +class TestSlackIndexer: + """Tests for Slack connector indexer.""" + + @pytest.mark.asyncio + async def test_index_slack_messages_connector_not_found(self): + """Test handling when connector is not found.""" + from app.tasks.connector_indexers.slack_indexer import index_slack_messages + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.slack_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + 
"app.tasks.connector_indexers.slack_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_slack_messages( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + @pytest.mark.asyncio + async def test_index_slack_messages_missing_token(self): + """Test handling when Slack token is missing.""" + from app.tasks.connector_indexers.slack_indexer import index_slack_messages + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = {} # No token + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.slack_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.slack_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + count, error = await index_slack_messages( + mock_session, 1, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "token" in error.lower() + + @pytest.mark.asyncio + async def test_index_slack_messages_no_channels_found(self): + """Test handling when no Slack channels are found.""" + from app.tasks.connector_indexers.slack_indexer import index_slack_messages + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = {"SLACK_BOT_TOKEN": "xoxb-test-token"} + mock_connector.last_indexed_at = None + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.slack_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.slack_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + with patch( + "app.tasks.connector_indexers.slack_indexer.SlackHistory" + ) as mock_slack: + mock_slack_instance = MagicMock() + mock_slack_instance.get_all_channels.return_value = [] + mock_slack.return_value = mock_slack_instance + + count, error = await index_slack_messages( + mock_session, 1, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "no slack channels found" in error.lower() + + @pytest.mark.asyncio + async def test_index_slack_messages_successful_indexing(self): + """Test successful Slack message indexing.""" + from app.tasks.connector_indexers.slack_indexer import index_slack_messages + + mock_session = AsyncMock() + # session.add is synchronous, so use MagicMock + mock_session.add = MagicMock() + mock_connector = MagicMock() + mock_connector.config = {"SLACK_BOT_TOKEN": "xoxb-test-token"} + mock_connector.last_indexed_at = None + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + mock_channels = [ + {"id": "C123", "name": "general", "is_private": False, "is_member": True} + ] + + mock_messages = [ + { + "ts": "1234567890.123456", + "datetime": "2024-01-15 10:00:00", + "user_name": "Test User", + 
"user_email": "test@example.com", + "text": "Hello world", + } + ] + + with patch( + "app.tasks.connector_indexers.slack_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.slack_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + with patch( + "app.tasks.connector_indexers.slack_indexer.SlackHistory" + ) as mock_slack: + mock_slack_instance = MagicMock() + mock_slack_instance.get_all_channels.return_value = mock_channels + mock_slack_instance.get_history_by_date_range.return_value = ( + mock_messages, + None, + ) + mock_slack_instance.format_message.return_value = mock_messages[0] + mock_slack.return_value = mock_slack_instance + + with patch( + "app.tasks.connector_indexers.slack_indexer.check_document_by_unique_identifier", + new_callable=AsyncMock, + ) as mock_check: + mock_check.return_value = None # No existing document + + with patch( + "app.tasks.connector_indexers.slack_indexer.create_document_chunks", + new_callable=AsyncMock, + ) as mock_chunks: + mock_chunks.return_value = [] + + with patch( + "app.tasks.connector_indexers.slack_indexer.config" + ) as mock_config: + mock_config.embedding_model_instance.embed.return_value = [ + 0.1, + 0.2, + ] + + count, error = await index_slack_messages( + mock_session, + 1, + 1, + "user1", + "2024-01-01", + "2024-12-31", + ) + + assert count >= 0 + mock_session.add.assert_called() + + @pytest.mark.asyncio + async def test_index_slack_messages_skips_private_channels(self): + """Test that private channels where bot is not a member are skipped.""" + from app.tasks.connector_indexers.slack_indexer import index_slack_messages + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = {"SLACK_BOT_TOKEN": "xoxb-test-token"} + mock_connector.last_indexed_at = None + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + # Only private channel where bot is not a member + mock_channels = [ + {"id": "C456", "name": "private-channel", "is_private": True, "is_member": False} + ] + + with patch( + "app.tasks.connector_indexers.slack_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.slack_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + with patch( + "app.tasks.connector_indexers.slack_indexer.SlackHistory" + ) as mock_slack: + mock_slack_instance = MagicMock() + mock_slack_instance.get_all_channels.return_value = mock_channels + mock_slack.return_value = mock_slack_instance + + count, error = await index_slack_messages( + mock_session, 1, 1, "user1", "2024-01-01", "2024-12-31" + ) + + # Should have processed but skipped the private channel + assert "skipped" in error.lower() or count == 0 + + @pytest.mark.asyncio + async def test_index_slack_messages_handles_api_error(self): + """Test handling of Slack API errors.""" + from app.tasks.connector_indexers.slack_indexer import index_slack_messages + from slack_sdk.errors import SlackApiError + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = {"SLACK_BOT_TOKEN": "xoxb-test-token"} + mock_connector.last_indexed_at = None + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = 
AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.slack_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.slack_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + with patch( + "app.tasks.connector_indexers.slack_indexer.SlackHistory" + ) as mock_slack: + mock_slack_instance = MagicMock() + mock_slack_instance.get_all_channels.side_effect = Exception( + "API error" + ) + mock_slack.return_value = mock_slack_instance + + count, error = await index_slack_messages( + mock_session, 1, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "failed" in error.lower() + + +# ============================================================================ +# NOTION INDEXER TESTS +# ============================================================================ + + +class TestNotionIndexer: + """Tests for Notion connector indexer.""" + + @pytest.mark.asyncio + async def test_index_notion_pages_connector_not_found(self): + """Test handling when connector is not found.""" + from app.tasks.connector_indexers.notion_indexer import index_notion_pages + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.notion_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.notion_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_notion_pages( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + @pytest.mark.asyncio + async def test_index_notion_pages_missing_token(self): + """Test handling when Notion token is missing.""" + from app.tasks.connector_indexers.notion_indexer import index_notion_pages + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = {} # No token + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.notion_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.notion_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + count, error = await index_notion_pages( + mock_session, 1, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "token" in error.lower() + + @pytest.mark.asyncio + async def test_index_notion_pages_no_pages_found(self): + """Test handling when no Notion pages are found.""" + from app.tasks.connector_indexers.notion_indexer import index_notion_pages + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = {"NOTION_INTEGRATION_TOKEN": "secret_token"} + mock_connector.last_indexed_at = None + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + 
mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + mock_notion_client = AsyncMock() + mock_notion_client.get_all_pages = AsyncMock(return_value=[]) + mock_notion_client.close = AsyncMock() + + with patch( + "app.tasks.connector_indexers.notion_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.notion_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + with patch( + "app.tasks.connector_indexers.notion_indexer.NotionHistoryConnector" + ) as mock_notion: + mock_notion.return_value = mock_notion_client + + count, error = await index_notion_pages( + mock_session, 1, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "no notion pages found" in error.lower() + + @pytest.mark.asyncio + async def test_index_notion_pages_successful_indexing(self): + """Test successful Notion page indexing.""" + from app.tasks.connector_indexers.notion_indexer import index_notion_pages + + mock_session = AsyncMock() + # session.add is synchronous, so use MagicMock + mock_session.add = MagicMock() + mock_connector = MagicMock() + mock_connector.config = {"NOTION_INTEGRATION_TOKEN": "secret_token"} + mock_connector.last_indexed_at = None + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + mock_pages = [ + { + "page_id": "page-123", + "title": "Test Page", + "content": [ + {"type": "paragraph", "content": "Test content", "children": []} + ], + } + ] + + mock_notion_client = AsyncMock() + mock_notion_client.get_all_pages = AsyncMock(return_value=mock_pages) + mock_notion_client.close = AsyncMock() + + with patch( + "app.tasks.connector_indexers.notion_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.notion_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + with patch( + "app.tasks.connector_indexers.notion_indexer.NotionHistoryConnector" + ) as mock_notion: + mock_notion.return_value = mock_notion_client + + with patch( + "app.tasks.connector_indexers.notion_indexer.check_document_by_unique_identifier", + new_callable=AsyncMock, + ) as mock_check: + mock_check.return_value = None + + with patch( + "app.tasks.connector_indexers.notion_indexer.get_user_long_context_llm", + new_callable=AsyncMock, + ) as mock_llm: + mock_llm.return_value = MagicMock() + + with patch( + "app.tasks.connector_indexers.notion_indexer.generate_document_summary", + new_callable=AsyncMock, + ) as mock_summary: + mock_summary.return_value = ( + "Summary", + [0.1, 0.2], + ) + + with patch( + "app.tasks.connector_indexers.notion_indexer.create_document_chunks", + new_callable=AsyncMock, + ) as mock_chunks: + mock_chunks.return_value = [] + + count, error = await index_notion_pages( + mock_session, + 1, + 1, + "user1", + "2024-01-01", + "2024-12-31", + ) + + assert count >= 0 + mock_notion_client.close.assert_called() + + @pytest.mark.asyncio + async def test_index_notion_pages_skips_empty_pages(self): + """Test that pages with no content are skipped.""" + from app.tasks.connector_indexers.notion_indexer import index_notion_pages + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = 
{"NOTION_INTEGRATION_TOKEN": "secret_token"} + mock_connector.last_indexed_at = None + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + # Page with no content + mock_pages = [{"page_id": "page-empty", "title": "Empty Page", "content": []}] + + mock_notion_client = AsyncMock() + mock_notion_client.get_all_pages = AsyncMock(return_value=mock_pages) + mock_notion_client.close = AsyncMock() + + with patch( + "app.tasks.connector_indexers.notion_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.notion_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + with patch( + "app.tasks.connector_indexers.notion_indexer.NotionHistoryConnector" + ) as mock_notion: + mock_notion.return_value = mock_notion_client + + count, error = await index_notion_pages( + mock_session, 1, 1, "user1", "2024-01-01", "2024-12-31" + ) + + # Should skip the empty page + assert "skipped" in error.lower() or count == 0 + + +# ============================================================================ +# GITHUB INDEXER TESTS +# ============================================================================ + + +class TestGitHubIndexer: + """Tests for GitHub connector indexer.""" + + @pytest.mark.asyncio + async def test_index_github_repos_connector_not_found(self): + """Test handling when connector is not found.""" + from app.tasks.connector_indexers.github_indexer import index_github_repos + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.github_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.github_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_github_repos( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + @pytest.mark.asyncio + async def test_index_github_repos_missing_pat(self): + """Test handling when GitHub PAT is missing.""" + from app.tasks.connector_indexers.github_indexer import index_github_repos + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = {"repo_full_names": ["owner/repo"]} # No PAT + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.github_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.github_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + count, error = await index_github_repos( + mock_session, 1, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "pat" in error.lower() or "token" in error.lower() + + @pytest.mark.asyncio + async def test_index_github_repos_missing_repo_list(self): + """Test handling when 
repo_full_names is missing.""" + from app.tasks.connector_indexers.github_indexer import index_github_repos + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = {"GITHUB_PAT": "ghp_test_token"} # No repo list + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.github_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.github_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + count, error = await index_github_repos( + mock_session, 1, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "repo_full_names" in error.lower() + + @pytest.mark.asyncio + async def test_index_github_repos_successful_indexing(self): + """Test successful GitHub repository indexing.""" + from app.tasks.connector_indexers.github_indexer import index_github_repos + + mock_session = AsyncMock() + # session.add is synchronous, so use MagicMock + mock_session.add = MagicMock() + mock_connector = MagicMock() + mock_connector.config = { + "GITHUB_PAT": "ghp_test_token", + "repo_full_names": ["owner/repo"], + } + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + mock_files = [ + { + "path": "README.md", + "url": "https://github.com/owner/repo/blob/main/README.md", + "sha": "abc123", + "type": "doc", + } + ] + + with patch( + "app.tasks.connector_indexers.github_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.github_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + with patch( + "app.tasks.connector_indexers.github_indexer.GitHubConnector" + ) as mock_github: + mock_github_instance = MagicMock() + mock_github_instance.get_repository_files.return_value = mock_files + mock_github_instance.get_file_content.return_value = ( + "# README\n\nTest content" + ) + mock_github.return_value = mock_github_instance + + with patch( + "app.tasks.connector_indexers.github_indexer.check_document_by_unique_identifier", + new_callable=AsyncMock, + ) as mock_check: + mock_check.return_value = None + + with patch( + "app.tasks.connector_indexers.github_indexer.get_user_long_context_llm", + new_callable=AsyncMock, + ) as mock_llm: + mock_llm.return_value = MagicMock() + + with patch( + "app.tasks.connector_indexers.github_indexer.generate_document_summary", + new_callable=AsyncMock, + ) as mock_summary: + mock_summary.return_value = ( + "Summary", + [0.1, 0.2], + ) + + with patch( + "app.tasks.connector_indexers.github_indexer.create_document_chunks", + new_callable=AsyncMock, + ) as mock_chunks: + mock_chunks.return_value = [] + + with patch( + "app.tasks.connector_indexers.github_indexer.config" + ) as mock_config: + mock_config.embedding_model_instance.embed.return_value = [ + 0.1, + 0.2, + ] + + count, error = await index_github_repos( + mock_session, + 1, + 1, + "user1", + "2024-01-01", + "2024-12-31", + ) + + assert count >= 0 + + @pytest.mark.asyncio + async def test_index_github_repos_handles_file_fetch_error(self): + """Test handling 
file content fetch errors.""" + from app.tasks.connector_indexers.github_indexer import index_github_repos + + mock_session = AsyncMock() + mock_connector = MagicMock() + mock_connector.config = { + "GITHUB_PAT": "ghp_test_token", + "repo_full_names": ["owner/repo"], + } + + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_success = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + mock_files = [ + {"path": "file.py", "url": "https://...", "sha": "def456", "type": "code"} + ] + + with patch( + "app.tasks.connector_indexers.github_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.github_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = mock_connector + + with patch( + "app.tasks.connector_indexers.github_indexer.GitHubConnector" + ) as mock_github: + mock_github_instance = MagicMock() + mock_github_instance.get_repository_files.return_value = mock_files + mock_github_instance.get_file_content.return_value = ( + None # File fetch failed + ) + mock_github.return_value = mock_github_instance + + count, error = await index_github_repos( + mock_session, + 1, + 1, + "user1", + "2024-01-01", + "2024-12-31", + ) + + # Should handle gracefully and continue + assert count == 0 + + +# ============================================================================ +# JIRA INDEXER TESTS +# ============================================================================ + + +class TestJiraIndexer: + """Tests for Jira connector indexer.""" + + @pytest.mark.asyncio + async def test_jira_indexer_connector_not_found(self): + """Test handling when Jira connector is not found.""" + from app.tasks.connector_indexers.jira_indexer import index_jira_issues + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.jira_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.jira_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_jira_issues( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# CONFLUENCE INDEXER TESTS +# ============================================================================ + + +class TestConfluenceIndexer: + """Tests for Confluence connector indexer.""" + + @pytest.mark.asyncio + async def test_confluence_indexer_connector_not_found(self): + """Test handling when Confluence connector is not found.""" + from app.tasks.connector_indexers.confluence_indexer import index_confluence_pages + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.confluence_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.confluence_indexer.get_connector_by_id", + 
new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_confluence_pages( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# LINEAR INDEXER TESTS +# ============================================================================ + + +class TestLinearIndexer: + """Tests for Linear connector indexer.""" + + @pytest.mark.asyncio + async def test_linear_indexer_connector_not_found(self): + """Test handling when Linear connector is not found.""" + from app.tasks.connector_indexers.linear_indexer import index_linear_issues + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.linear_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.linear_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_linear_issues( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# DISCORD INDEXER TESTS +# ============================================================================ + + +class TestDiscordIndexer: + """Tests for Discord connector indexer.""" + + @pytest.mark.asyncio + async def test_discord_indexer_connector_not_found(self): + """Test handling when Discord connector is not found.""" + from app.tasks.connector_indexers.discord_indexer import index_discord_messages + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.discord_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.discord_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_discord_messages( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# GOOGLE CALENDAR INDEXER TESTS +# ============================================================================ + + +class TestGoogleCalendarIndexer: + """Tests for Google Calendar connector indexer.""" + + @pytest.mark.asyncio + async def test_google_calendar_indexer_connector_not_found(self): + """Test handling when Google Calendar connector is not found.""" + from app.tasks.connector_indexers.google_calendar_indexer import ( + index_google_calendar_events, + ) + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.google_calendar_indexer.TaskLoggingService", + 
return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.google_calendar_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_google_calendar_events( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# AIRTABLE INDEXER TESTS +# ============================================================================ + + +class TestAirtableIndexer: + """Tests for Airtable connector indexer.""" + + @pytest.mark.asyncio + async def test_airtable_indexer_connector_not_found(self): + """Test handling when Airtable connector is not found.""" + from app.tasks.connector_indexers.airtable_indexer import index_airtable_records + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.airtable_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.airtable_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_airtable_records( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# WEBCRAWLER INDEXER TESTS +# ============================================================================ + + +class TestWebcrawlerIndexer: + """Tests for Webcrawler connector indexer.""" + + @pytest.mark.asyncio + async def test_webcrawler_indexer_connector_not_found(self): + """Test handling when Webcrawler connector is not found.""" + from app.tasks.connector_indexers.webcrawler_indexer import index_crawled_urls + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.webcrawler_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.webcrawler_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_crawled_urls( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# ELASTICSEARCH INDEXER TESTS +# ============================================================================ + + +class TestElasticsearchIndexer: + """Tests for Elasticsearch connector indexer.""" + + @pytest.mark.asyncio + async def test_elasticsearch_indexer_connector_not_found(self): + """Test handling when Elasticsearch connector is not found.""" + from app.tasks.connector_indexers.elasticsearch_indexer import ( + index_elasticsearch_documents, + ) + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = 
AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + # Mock the session.execute to return no connector + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute.return_value = mock_result + + with patch( + "app.tasks.connector_indexers.elasticsearch_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + count, error = await index_elasticsearch_documents( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# LUMA INDEXER TESTS +# ============================================================================ + + +class TestLumaIndexer: + """Tests for Luma connector indexer.""" + + @pytest.mark.asyncio + async def test_luma_indexer_connector_not_found(self): + """Test handling when Luma connector is not found.""" + from app.tasks.connector_indexers.luma_indexer import index_luma_events + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.luma_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.luma_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_luma_events( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# GOOGLE GMAIL INDEXER TESTS +# ============================================================================ + + +class TestGoogleGmailIndexer: + """Tests for Google Gmail connector indexer.""" + + @pytest.mark.asyncio + async def test_google_gmail_indexer_connector_not_found(self): + """Test handling when Google Gmail connector is not found.""" + from app.tasks.connector_indexers.google_gmail_indexer import ( + index_google_gmail_messages, + ) + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.google_gmail_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.google_gmail_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_google_gmail_messages( + mock_session, 999, 1, "user1", 100, 30 + ) + + assert count == 0 + assert "not found" in error.lower() + + +# ============================================================================ +# CLICKUP INDEXER TESTS +# ============================================================================ + + +class TestClickupIndexer: + """Tests for ClickUp connector indexer.""" + + @pytest.mark.asyncio + async def test_clickup_indexer_connector_not_found(self): + """Test handling when ClickUp connector is not found.""" + from app.tasks.connector_indexers.clickup_indexer import index_clickup_tasks + + mock_session = AsyncMock() + mock_task_logger = MagicMock() + 
mock_task_logger.log_task_start = AsyncMock(return_value=MagicMock()) + mock_task_logger.log_task_failure = AsyncMock() + mock_task_logger.log_task_progress = AsyncMock() + + with patch( + "app.tasks.connector_indexers.clickup_indexer.TaskLoggingService", + return_value=mock_task_logger, + ): + with patch( + "app.tasks.connector_indexers.clickup_indexer.get_connector_by_id", + new_callable=AsyncMock, + ) as mock_get_connector: + mock_get_connector.return_value = None + + count, error = await index_clickup_tasks( + mock_session, 999, 1, "user1", "2024-01-01", "2024-12-31" + ) + + assert count == 0 + assert "not found" in error.lower() diff --git a/surfsense_backend/tests/test_connector_service.py b/surfsense_backend/tests/test_connector_service.py new file mode 100644 index 000000000..292ed2e9a --- /dev/null +++ b/surfsense_backend/tests/test_connector_service.py @@ -0,0 +1,489 @@ +""" +Tests for the ConnectorService class. + +These tests validate: +1. Search results are properly transformed with correct structure +2. Missing connectors are handled gracefully (empty results, not errors) +3. Counter initialization is resilient to database errors +4. Search modes affect which retriever is used +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +# Skip these tests if app dependencies aren't installed +pytest.importorskip("linkup") +pytest.importorskip("litellm") + +from app.services.connector_service import ConnectorService +from app.agents.researcher.configuration import SearchMode + + +class TestConnectorServiceResilience: + """Tests for ConnectorService resilience and error handling.""" + + def test_init_sets_safe_defaults(self, mock_session): + """ + Service must initialize with safe defaults. + Critical: source_id_counter must never start at 0 (collision risk). + """ + service = ConnectorService(mock_session) + + # Must have a high starting counter to avoid collisions with existing data + assert service.source_id_counter >= 100000 + + @pytest.mark.asyncio + async def test_counter_init_survives_database_error(self, mock_session): + """ + Counter initialization must not crash on database errors. + This is critical - a DB error during init shouldn't break the service. + """ + from sqlalchemy.exc import SQLAlchemyError + + service = ConnectorService(mock_session, search_space_id=1) + mock_session.execute = AsyncMock(side_effect=SQLAlchemyError("DB error")) + + # Must not raise + await service.initialize_counter() + + # Must have a usable counter value + assert service.source_id_counter >= 1 + + @pytest.mark.asyncio + async def test_counter_init_without_search_space_is_no_op(self, mock_session): + """ + When no search_space_id is provided, counter init should be a no-op. + Calling the database without a search_space_id would be wasteful. + """ + service = ConnectorService(mock_session, search_space_id=None) + + await service.initialize_counter() + + # Should NOT have called database + mock_session.execute.assert_not_called() + + @pytest.mark.asyncio + async def test_counter_continues_from_existing_chunks(self, mock_session): + """ + Counter must continue from the highest existing source_id + 1. + Starting lower would cause ID collisions. 
+ """ + service = ConnectorService(mock_session, search_space_id=1) + + mock_result = MagicMock() + mock_result.scalar.return_value = 500 # Max existing source_id + mock_session.execute = AsyncMock(return_value=mock_result) + + await service.initialize_counter() + + # Must be > 500 to avoid collision + assert service.source_id_counter == 501 + + +class TestSearchResultTransformation: + """ + Tests validating search result transformation produces correct output structure. + """ + + def test_transform_empty_list_returns_empty(self, mock_session): + """Empty input must return empty output - not None or error.""" + service = ConnectorService(mock_session) + result = service._transform_document_results([]) + + assert result == [] + assert isinstance(result, list) + + def test_transform_preserves_all_required_fields(self, mock_session): + """ + Transformation must preserve all fields needed by the frontend. + Missing fields would break the UI. + """ + service = ConnectorService(mock_session) + + input_docs = [ + { + "document_id": 42, + "title": "Important Doc", + "document_type": "FILE", + "metadata": {"url": "https://example.com/doc"}, + "chunks_content": "The actual content", + "score": 0.87, + } + ] + + result = service._transform_document_results(input_docs) + + assert len(result) == 1 + transformed = result[0] + + # All these fields are required by the frontend + assert "chunk_id" in transformed + assert "document" in transformed + assert "content" in transformed + assert "score" in transformed + + # Nested document structure must be correct + assert "id" in transformed["document"] + assert "title" in transformed["document"] + assert "document_type" in transformed["document"] + assert "metadata" in transformed["document"] + + def test_transform_uses_chunks_content_over_content(self, mock_session): + """ + When chunks_content exists, it should be used over content field. + This ensures full content is returned, not truncated. + """ + service = ConnectorService(mock_session) + + input_docs = [ + { + "document_id": 1, + "title": "Test", + "document_type": "FILE", + "metadata": {}, + "content": "Short preview", + "chunks_content": "Full document content that is much longer", + "score": 0.8, + } + ] + + result = service._transform_document_results(input_docs) + + # Must use chunks_content, not content + assert result[0]["content"] == "Full document content that is much longer" + + def test_transform_falls_back_to_content_when_no_chunks_content(self, mock_session): + """ + When chunks_content is missing, fall back to content field. + Must not error or return empty content. + """ + service = ConnectorService(mock_session) + + input_docs = [ + { + "document_id": 1, + "title": "Test", + "document_type": "FILE", + "metadata": {}, + "content": "Fallback content", + "score": 0.8, + } + ] + + result = service._transform_document_results(input_docs) + + assert result[0]["content"] == "Fallback content" + + +class TestMissingConnectorHandling: + """ + Tests validating graceful handling when connectors are not configured. + """ + + @pytest.mark.asyncio + async def test_missing_tavily_connector_returns_empty_not_error(self, mock_session): + """ + Missing Tavily connector must return empty results, not raise exception. + This is important - users without the connector shouldn't see errors. 
+ """ + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service, "get_connector_by_type", new_callable=AsyncMock + ) as mock_get: + mock_get.return_value = None + + result_obj, docs = await service.search_tavily( + "test query", search_space_id=1 + ) + + # Must return valid structure with empty sources + assert result_obj["type"] == "TAVILY_API" + assert result_obj["sources"] == [] + assert docs == [] + # No exception should have been raised + + @pytest.mark.asyncio + async def test_missing_searxng_connector_returns_empty_not_error(self, mock_session): + """Missing SearxNG connector must return empty results gracefully.""" + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service, "get_connector_by_type", new_callable=AsyncMock + ) as mock_get: + mock_get.return_value = None + + result_obj, docs = await service.search_searxng( + "test query", search_space_id=1 + ) + + assert result_obj["type"] == "SEARXNG_API" + assert result_obj["sources"] == [] + + @pytest.mark.asyncio + async def test_missing_baidu_connector_returns_empty_not_error(self, mock_session): + """Missing Baidu connector must return empty results gracefully.""" + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service, "get_connector_by_type", new_callable=AsyncMock + ) as mock_get: + mock_get.return_value = None + + result_obj, docs = await service.search_baidu( + "test query", search_space_id=1 + ) + + assert result_obj["type"] == "BAIDU_SEARCH_API" + assert result_obj["sources"] == [] + + +class TestSearchResultStructure: + """ + Tests validating that search results have correct structure. + """ + + @pytest.mark.asyncio + async def test_crawled_urls_result_has_correct_type(self, mock_session): + """ + Crawled URL search results must have type "CRAWLED_URL". + Wrong type would break filtering in the frontend. 
+ """ + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = [] + + result_obj, _ = await service.search_crawled_urls( + "test query", search_space_id=1, top_k=10 + ) + + assert result_obj["type"] == "CRAWLED_URL" + + @pytest.mark.asyncio + async def test_files_result_has_correct_type(self, mock_session): + """File search results must have type "FILE".""" + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = [] + + result_obj, _ = await service.search_files( + "test query", search_space_id=1, top_k=10 + ) + + assert result_obj["type"] == "FILE" + + @pytest.mark.asyncio + async def test_slack_result_has_correct_type(self, mock_session): + """Slack search results must have type "SLACK_CONNECTOR".""" + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = [] + + result_obj, _ = await service.search_slack( + "test query", search_space_id=1, top_k=10 + ) + + assert result_obj["type"] == "SLACK_CONNECTOR" + + @pytest.mark.asyncio + async def test_notion_result_has_correct_type(self, mock_session): + """Notion search results must have type "NOTION_CONNECTOR".""" + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = [] + + result_obj, _ = await service.search_notion( + "test query", search_space_id=1, top_k=10 + ) + + assert result_obj["type"] == "NOTION_CONNECTOR" + + @pytest.mark.asyncio + async def test_github_result_has_correct_type(self, mock_session): + """GitHub search results must have type "GITHUB_CONNECTOR".""" + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = [] + + result_obj, _ = await service.search_github( + "test query", search_space_id=1, top_k=10 + ) + + assert result_obj["type"] == "GITHUB_CONNECTOR" + + @pytest.mark.asyncio + async def test_youtube_result_has_correct_type(self, mock_session): + """YouTube search results must have type "YOUTUBE_VIDEO".""" + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = [] + + result_obj, _ = await service.search_youtube( + "test query", search_space_id=1, top_k=10 + ) + + assert result_obj["type"] == "YOUTUBE_VIDEO" + + +class TestSearchModeAffectsRetriever: + """ + Tests validating that search mode affects which retriever is used. + """ + + @pytest.mark.asyncio + async def test_documents_mode_uses_document_retriever(self, mock_session): + """ + DOCUMENTS mode must use document_retriever, not chunk_retriever. + Using wrong retriever would return wrong result granularity. 
+ """ + service = ConnectorService(mock_session, search_space_id=1) + + mock_docs = [ + { + "document_id": 1, + "title": "Test", + "document_type": "FILE", + "metadata": {}, + "chunks_content": "content", + "score": 0.9, + } + ] + + with patch.object( + service.document_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_doc_search: + mock_doc_search.return_value = mock_docs + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_chunk_search: + + await service.search_files( + "test query", + search_space_id=1, + top_k=10, + search_mode=SearchMode.DOCUMENTS, + ) + + # Document retriever should have been called + mock_doc_search.assert_called_once() + # Chunk retriever should NOT have been called + mock_chunk_search.assert_not_called() + + @pytest.mark.asyncio + async def test_chunks_mode_uses_chunk_retriever(self, mock_session): + """ + Default/CHUNKS mode must use chunk_retriever. + """ + service = ConnectorService(mock_session, search_space_id=1) + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_chunk_search: + mock_chunk_search.return_value = [] + + with patch.object( + service.document_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_doc_search: + + await service.search_files( + "test query", + search_space_id=1, + top_k=10, + # Default mode (no search_mode specified) + ) + + # Chunk retriever should have been called + mock_chunk_search.assert_called_once() + # Document retriever should NOT have been called + mock_doc_search.assert_not_called() + + +class TestSearchResultMetadataExtraction: + """ + Tests validating that metadata is correctly extracted for different source types. + """ + + @pytest.mark.asyncio + async def test_crawled_url_extracts_source_as_url(self, mock_session): + """ + Crawled URL results must extract 'source' from metadata as URL. + Wrong field would break link navigation. + """ + service = ConnectorService(mock_session, search_space_id=1) + + mock_chunks = [ + { + "chunk_id": 1, + "content": "Page content", + "document": { + "title": "Web Page", + "metadata": {"source": "https://example.com/page"}, + }, + } + ] + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = mock_chunks + + result_obj, _ = await service.search_crawled_urls( + "test", search_space_id=1, top_k=10 + ) + + assert result_obj["sources"][0]["url"] == "https://example.com/page" + + @pytest.mark.asyncio + async def test_youtube_extracts_video_metadata(self, mock_session): + """ + YouTube results must extract video_id and other video metadata. + Missing video_id would break video embedding. 
+ """ + service = ConnectorService(mock_session, search_space_id=1) + + mock_chunks = [ + { + "chunk_id": 1, + "content": "Transcript", + "document": { + "title": "YouTube", + "metadata": { + "video_title": "Test Video", + "video_id": "dQw4w9WgXcQ", + "channel_name": "Test Channel", + }, + }, + } + ] + + with patch.object( + service.chunk_retriever, "hybrid_search", new_callable=AsyncMock + ) as mock_search: + mock_search.return_value = mock_chunks + + result_obj, _ = await service.search_youtube( + "test", search_space_id=1, top_k=10 + ) + + source = result_obj["sources"][0] + assert source["video_id"] == "dQw4w9WgXcQ" + assert "Test Video" in source["title"] diff --git a/surfsense_backend/tests/test_connector_service_extended.py b/surfsense_backend/tests/test_connector_service_extended.py new file mode 100644 index 000000000..cf200d748 --- /dev/null +++ b/surfsense_backend/tests/test_connector_service_extended.py @@ -0,0 +1,490 @@ +""" +Extended tests for connector service. +Tests the ConnectorService class with mocked database and external dependencies. +""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from app.services.connector_service import ConnectorService +from app.agents.researcher.configuration import SearchMode + + +class TestConnectorServiceInitialization: + """Tests for ConnectorService initialization.""" + + def test_init_with_search_space_id(self): + """Test initialization with search space ID.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + assert service.session == mock_session + assert service.search_space_id == 1 + assert service.source_id_counter == 100000 + + def test_init_without_search_space_id(self): + """Test initialization without search space ID.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session) + + assert service.search_space_id is None + + @pytest.mark.asyncio + async def test_initialize_counter_success(self): + """Test counter initialization from database.""" + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.scalar.return_value = 50 + mock_session.execute = AsyncMock(return_value=mock_result) + + service = ConnectorService(mock_session, search_space_id=1) + await service.initialize_counter() + + assert service.source_id_counter == 51 + + @pytest.mark.asyncio + async def test_initialize_counter_database_error(self): + """Test counter initialization handles database errors gracefully.""" + from sqlalchemy.exc import SQLAlchemyError + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(side_effect=SQLAlchemyError("DB Error")) + + service = ConnectorService(mock_session, search_space_id=1) + await service.initialize_counter() + + # Should fallback to 1 + assert service.source_id_counter == 1 + + +class TestSearchCrawledUrls: + """Tests for search_crawled_urls method.""" + + @pytest.mark.asyncio + async def test_search_crawled_urls_empty_results(self): + """Test search with no results.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + # Mock the chunk retriever + service.chunk_retriever = MagicMock() + service.chunk_retriever.hybrid_search = AsyncMock(return_value=[]) + + result, chunks = await service.search_crawled_urls( + user_query="test query", + search_space_id=1, + top_k=20, + search_mode=SearchMode.CHUNKS, + ) + + assert result["type"] == "CRAWLED_URL" + assert result["sources"] == [] + assert chunks == [] + + @pytest.mark.asyncio + async def 
test_search_crawled_urls_with_results(self): + """Test search with results.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + # Mock the chunk retriever + mock_chunks = [ + { + "chunk_id": 1, + "content": "Test content", + "document": { + "title": "Test Document", + "metadata": { + "source": "https://example.com", + "description": "Test description", + }, + }, + } + ] + service.chunk_retriever = MagicMock() + service.chunk_retriever.hybrid_search = AsyncMock(return_value=mock_chunks) + + result, chunks = await service.search_crawled_urls( + user_query="test query", + search_space_id=1, + top_k=20, + search_mode=SearchMode.CHUNKS, + ) + + assert result["type"] == "CRAWLED_URL" + assert len(result["sources"]) == 1 + assert result["sources"][0]["title"] == "Test Document" + assert len(chunks) == 1 + + +class TestSearchFiles: + """Tests for search_files method.""" + + @pytest.mark.asyncio + async def test_search_files_empty_results(self): + """Test file search with no results.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + # Mock the chunk retriever + service.chunk_retriever = MagicMock() + service.chunk_retriever.hybrid_search = AsyncMock(return_value=[]) + + result, chunks = await service.search_files( + user_query="test query", + search_space_id=1, + top_k=20, + search_mode=SearchMode.CHUNKS, + ) + + assert result["type"] == "FILE" + assert result["sources"] == [] + assert chunks == [] + + +class TestSearchDocuments: + """Tests for document search mode.""" + + @pytest.mark.asyncio + async def test_search_uses_document_retriever_in_documents_mode(self): + """Test that document mode uses document retriever.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + # Mock both retrievers + service.chunk_retriever = MagicMock() + service.chunk_retriever.hybrid_search = AsyncMock(return_value=[]) + service.document_retriever = MagicMock() + service.document_retriever.hybrid_search = AsyncMock(return_value=[]) + + await service.search_crawled_urls( + user_query="test query", + search_space_id=1, + top_k=20, + search_mode=SearchMode.DOCUMENTS, + ) + + # Document retriever should be called, not chunk retriever + assert service.document_retriever.hybrid_search.called + + +class TestTransformDocumentResults: + """Tests for _transform_document_results method.""" + + def test_transform_empty_list(self): + """Test transformation of empty results.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session) + + result = service._transform_document_results([]) + + assert result == [] + + def test_transform_document_with_chunks_content(self): + """Test transformation uses chunks_content when available.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session) + + input_docs = [ + { + "document_id": 1, + "title": "Test", + "document_type": "FILE", + "metadata": {}, + "content": "Short", + "chunks_content": "Full content from chunks", + "score": 0.8, + } + ] + + result = service._transform_document_results(input_docs) + + assert len(result) == 1 + assert result[0]["content"] == "Full content from chunks" + + def test_transform_document_falls_back_to_content(self): + """Test transformation falls back to content when no chunks_content.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session) + + input_docs = [ + { + "document_id": 1, + "title": "Test", + "document_type": "FILE", + "metadata": {}, + "content": "Only content 
available", + "score": 0.8, + } + ] + + result = service._transform_document_results(input_docs) + + assert len(result) == 1 + assert result[0]["content"] == "Only content available" + + +class TestSearchExtension: + """Tests for extension document search.""" + + @pytest.mark.asyncio + async def test_search_extension_documents(self): + """Test searching extension documents.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + # Mock the chunk retriever + mock_chunks = [ + { + "chunk_id": 1, + "content": "Browser captured content", + "document": { + "title": "Web Page Title", + "metadata": { + "url": "https://example.com/page", + "BrowsingSessionId": "session-123", + }, + }, + } + ] + service.chunk_retriever = MagicMock() + service.chunk_retriever.hybrid_search = AsyncMock(return_value=mock_chunks) + + result, chunks = await service.search_extension( + user_query="test", + search_space_id=1, + top_k=20, + search_mode=SearchMode.CHUNKS, + ) + + assert result["type"] == "EXTENSION" + assert len(result["sources"]) == 1 + + +class TestSearchSlack: + """Tests for Slack connector search.""" + + @pytest.mark.asyncio + async def test_search_slack_documents(self): + """Test searching Slack documents.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + # Mock the chunk retriever + mock_chunks = [ + { + "chunk_id": 1, + "content": "Slack message content", + "document": { + "title": "Slack Channel - #general", + "metadata": { + "channel_name": "general", + "username": "john_doe", + "timestamp": "2024-01-01T12:00:00Z", + }, + }, + } + ] + service.chunk_retriever = MagicMock() + service.chunk_retriever.hybrid_search = AsyncMock(return_value=mock_chunks) + + result, chunks = await service.search_slack( + user_query="test", + search_space_id=1, + top_k=20, + search_mode=SearchMode.CHUNKS, + ) + + assert result["type"] == "SLACK_CONNECTOR" + assert len(result["sources"]) == 1 + + +class TestSearchNotion: + """Tests for Notion connector search.""" + + @pytest.mark.asyncio + async def test_search_notion_documents(self): + """Test searching Notion documents.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + # Mock the chunk retriever + mock_chunks = [ + { + "chunk_id": 1, + "content": "Notion page content", + "document": { + "title": "Meeting Notes", + "metadata": { + "page_id": "notion-page-123", + "url": "https://notion.so/page", + }, + }, + } + ] + service.chunk_retriever = MagicMock() + service.chunk_retriever.hybrid_search = AsyncMock(return_value=mock_chunks) + + result, chunks = await service.search_notion( + user_query="test", + search_space_id=1, + top_k=20, + search_mode=SearchMode.CHUNKS, + ) + + assert result["type"] == "NOTION_CONNECTOR" + assert len(result["sources"]) == 1 + + +class TestSearchYoutube: + """Tests for YouTube document search.""" + + @pytest.mark.asyncio + async def test_search_youtube_documents(self): + """Test searching YouTube documents.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + # Mock the chunk retriever + mock_chunks = [ + { + "chunk_id": 1, + "content": "Video transcript content", + "document": { + "title": "YouTube Video Title", + "metadata": { + "video_id": "dQw4w9WgXcQ", + "channel": "Channel Name", + "duration": "3:45", + }, + }, + } + ] + service.chunk_retriever = MagicMock() + service.chunk_retriever.hybrid_search = AsyncMock(return_value=mock_chunks) + + result, chunks = await 
service.search_youtube( + user_query="test", + search_space_id=1, + top_k=20, + search_mode=SearchMode.CHUNKS, + ) + + assert result["type"] == "YOUTUBE_VIDEO" + assert len(result["sources"]) == 1 + + +class TestSearchGithub: + """Tests for GitHub connector search.""" + + @pytest.mark.asyncio + async def test_search_github_documents(self): + """Test searching GitHub documents.""" + mock_session = AsyncMock() + service = ConnectorService(mock_session, search_space_id=1) + + # Mock the chunk retriever + mock_chunks = [ + { + "chunk_id": 1, + "content": "Code content from GitHub", + "document": { + "title": "repo/file.py", + "metadata": { + "repo": "owner/repo", + "path": "src/file.py", + "branch": "main", + }, + }, + } + ] + service.chunk_retriever = MagicMock() + service.chunk_retriever.hybrid_search = AsyncMock(return_value=mock_chunks) + + result, chunks = await service.search_github( + user_query="test", + search_space_id=1, + top_k=20, + search_mode=SearchMode.CHUNKS, + ) + + assert result["type"] == "GITHUB_CONNECTOR" + assert len(result["sources"]) == 1 + + +class TestExternalSearchConnectors: + """Tests for external search API connectors.""" + + @pytest.mark.asyncio + async def test_tavily_search_no_connector(self): + """Test Tavily search returns empty when no connector configured.""" + mock_session = AsyncMock() + + # Mock no connector found + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + service = ConnectorService(mock_session, search_space_id=1) + + result = await service.search_tavily( + user_query="test", + search_space_id=1, + ) + + # Returns a tuple (sources_info_dict, documents_list) + sources_info, documents = result + assert sources_info["type"] == "TAVILY_API" + assert sources_info["sources"] == [] + assert documents == [] + + @pytest.mark.asyncio + async def test_linkup_search_no_connector(self): + """Test Linkup search returns empty when no connector configured.""" + mock_session = AsyncMock() + + # Mock no connector found + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + service = ConnectorService(mock_session, search_space_id=1) + + result = await service.search_linkup( + user_query="test", + search_space_id=1, + ) + + # Returns a tuple (sources_info_dict, documents_list) + sources_info, documents = result + assert sources_info["type"] == "LINKUP_API" + assert sources_info["sources"] == [] + assert documents == [] + + @pytest.mark.asyncio + async def test_searxng_search_no_connector(self): + """Test SearXNG search returns empty when no connector configured.""" + mock_session = AsyncMock() + + # Mock no connector found + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + service = ConnectorService(mock_session, search_space_id=1) + + result = await service.search_searxng( + user_query="test", + search_space_id=1, + ) + + # Returns a tuple (sources_info_dict, documents_list) + sources_info, documents = result + assert sources_info["type"] == "SEARXNG_API" + assert sources_info["sources"] == [] + assert documents == [] diff --git a/surfsense_backend/tests/test_llm_service.py 
b/surfsense_backend/tests/test_llm_service.py new file mode 100644 index 000000000..03667ed0d --- /dev/null +++ b/surfsense_backend/tests/test_llm_service.py @@ -0,0 +1,307 @@ +""" +Tests for the LLM service module. + +These tests validate: +1. LLM role constants have correct values (used for routing) +2. Global vs user-space LLM config lookup is correct +3. Missing LLMs are handled gracefully (return None, not crash) +4. Role-to-LLM mapping is correct (fast -> fast_llm_id, etc.) +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +# Skip these tests if app dependencies aren't installed +pytest.importorskip("litellm") + +from app.services.llm_service import ( + LLMRole, + get_global_llm_config, + get_fast_llm, + get_long_context_llm, + get_strategic_llm, + get_search_space_llm_instance, +) + + +class TestLLMRoleConstants: + """ + Tests for LLMRole constants. + These values are used for database lookups and must be stable. + """ + + def test_role_constants_are_strings(self): + """LLM role values must be strings for database compatibility.""" + assert isinstance(LLMRole.LONG_CONTEXT, str) + assert isinstance(LLMRole.FAST, str) + assert isinstance(LLMRole.STRATEGIC, str) + + def test_role_values_are_unique(self): + """Role values must be unique to prevent routing confusion.""" + roles = [LLMRole.LONG_CONTEXT, LLMRole.FAST, LLMRole.STRATEGIC] + assert len(roles) == len(set(roles)) + + def test_expected_role_values(self): + """ + Validate exact role values. + These are used in the database schema and must not change. + """ + assert LLMRole.LONG_CONTEXT == "long_context" + assert LLMRole.FAST == "fast" + assert LLMRole.STRATEGIC == "strategic" + + +class TestGlobalLLMConfigLookup: + """ + Tests validating global (negative ID) LLM config lookup behavior. + """ + + def test_positive_id_never_returns_global_config(self): + """ + Positive IDs are user-space configs, must never match global. + Returning a global config for a user ID would be a security issue. + """ + result = get_global_llm_config(1) + assert result is None + + result = get_global_llm_config(100) + assert result is None + + result = get_global_llm_config(999999) + assert result is None + + def test_zero_id_never_returns_global_config(self): + """Zero is not a valid global config ID.""" + result = get_global_llm_config(0) + assert result is None + + @patch("app.services.llm_service.config") + def test_negative_id_matches_correct_global_config(self, mock_config): + """ + Negative IDs should match global configs by exact ID. + Wrong matching would return wrong model configuration. 
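+ Global configs are plain dicts in config.GLOBAL_LLM_CONFIGS, each carrying a negative "id" alongside provider and model details, as in the patched list below.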
+ """ + mock_config.GLOBAL_LLM_CONFIGS = [ + {"id": -1, "provider": "OPENAI", "model_name": "gpt-4"}, + {"id": -2, "provider": "ANTHROPIC", "model_name": "claude-3"}, + {"id": -3, "provider": "GOOGLE", "model_name": "gemini-pro"}, + ] + + # Each ID should return its exact match + result_1 = get_global_llm_config(-1) + assert result_1["id"] == -1 + assert result_1["provider"] == "OPENAI" + + result_2 = get_global_llm_config(-2) + assert result_2["id"] == -2 + assert result_2["provider"] == "ANTHROPIC" + + result_3 = get_global_llm_config(-3) + assert result_3["id"] == -3 + assert result_3["provider"] == "GOOGLE" + + @patch("app.services.llm_service.config") + def test_non_existent_negative_id_returns_none(self, mock_config): + """Non-existent global config IDs must return None, not error.""" + mock_config.GLOBAL_LLM_CONFIGS = [ + {"id": -1, "provider": "OPENAI", "model_name": "gpt-4"}, + ] + + result = get_global_llm_config(-999) + assert result is None + + +class TestSearchSpaceLLMInstanceRetrieval: + """ + Tests for search space LLM instance retrieval. + Validates correct role-to-field mapping and graceful error handling. + """ + + @pytest.mark.asyncio + async def test_nonexistent_search_space_returns_none(self, mock_session): + """ + Missing search space must return None, not raise exception. + This prevents crashes when search spaces are deleted. + """ + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = None + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_search_space_llm_instance( + mock_session, search_space_id=999, role=LLMRole.FAST + ) + + assert result is None + + @pytest.mark.asyncio + async def test_invalid_role_returns_none(self, mock_session): + """ + Invalid role must return None to prevent undefined behavior. + """ + mock_search_space = MagicMock() + mock_search_space.fast_llm_id = 1 + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_search_space + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_search_space_llm_instance( + mock_session, search_space_id=1, role="not_a_valid_role" + ) + + assert result is None + + @pytest.mark.asyncio + async def test_unconfigured_llm_returns_none(self, mock_session): + """ + When no LLM is configured for a role, return None. + This is a valid state - not all search spaces have all LLMs. + """ + mock_search_space = MagicMock() + mock_search_space.fast_llm_id = None + mock_search_space.long_context_llm_id = None + mock_search_space.strategic_llm_id = None + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_search_space + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_search_space_llm_instance( + mock_session, search_space_id=1, role=LLMRole.FAST + ) + + assert result is None + + @pytest.mark.asyncio + @patch("app.services.llm_service.get_global_llm_config") + @patch("app.services.llm_service.ChatLiteLLM") + async def test_global_config_creates_llm_instance( + self, mock_chat_litellm, mock_get_global, mock_session + ): + """ + Global config (negative ID) should create an LLM instance. 
+ """ + mock_search_space = MagicMock() + mock_search_space.fast_llm_id = -1 # Global config ID + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_search_space + mock_session.execute = AsyncMock(return_value=mock_result) + + mock_get_global.return_value = { + "id": -1, + "provider": "OPENAI", + "model_name": "gpt-4", + "api_key": "test-key", + } + + mock_llm_instance = MagicMock() + mock_chat_litellm.return_value = mock_llm_instance + + result = await get_search_space_llm_instance( + mock_session, search_space_id=1, role=LLMRole.FAST + ) + + # Must return an LLM instance + assert result == mock_llm_instance + # Must have attempted to create ChatLiteLLM + mock_chat_litellm.assert_called_once() + + @pytest.mark.asyncio + @patch("app.services.llm_service.get_global_llm_config") + async def test_missing_global_config_returns_none( + self, mock_get_global, mock_session + ): + """ + If global config ID is set but config doesn't exist, return None. + This handles config deletion gracefully. + """ + mock_search_space = MagicMock() + mock_search_space.fast_llm_id = -1 # Global ID that doesn't exist + + mock_result = MagicMock() + mock_result.scalars.return_value.first.return_value = mock_search_space + mock_session.execute = AsyncMock(return_value=mock_result) + + mock_get_global.return_value = None # Config not found + + result = await get_search_space_llm_instance( + mock_session, search_space_id=1, role=LLMRole.FAST + ) + + assert result is None + + +class TestRoleToLLMMapping: + """ + Tests validating that convenience functions map to correct roles. + Wrong mapping would use wrong model (e.g., slow model for fast tasks). + """ + + @pytest.mark.asyncio + @patch("app.services.llm_service.get_search_space_llm_instance") + async def test_get_fast_llm_uses_fast_role(self, mock_get_instance, mock_session): + """get_fast_llm must request LLMRole.FAST specifically.""" + mock_llm = MagicMock() + mock_get_instance.return_value = mock_llm + + await get_fast_llm(mock_session, search_space_id=1) + + mock_get_instance.assert_called_once_with( + mock_session, 1, LLMRole.FAST + ) + + @pytest.mark.asyncio + @patch("app.services.llm_service.get_search_space_llm_instance") + async def test_get_long_context_llm_uses_long_context_role( + self, mock_get_instance, mock_session + ): + """get_long_context_llm must request LLMRole.LONG_CONTEXT specifically.""" + mock_llm = MagicMock() + mock_get_instance.return_value = mock_llm + + await get_long_context_llm(mock_session, search_space_id=1) + + mock_get_instance.assert_called_once_with( + mock_session, 1, LLMRole.LONG_CONTEXT + ) + + @pytest.mark.asyncio + @patch("app.services.llm_service.get_search_space_llm_instance") + async def test_get_strategic_llm_uses_strategic_role( + self, mock_get_instance, mock_session + ): + """get_strategic_llm must request LLMRole.STRATEGIC specifically.""" + mock_llm = MagicMock() + mock_get_instance.return_value = mock_llm + + await get_strategic_llm(mock_session, search_space_id=1) + + mock_get_instance.assert_called_once_with( + mock_session, 1, LLMRole.STRATEGIC + ) + + @pytest.mark.asyncio + @patch("app.services.llm_service.get_search_space_llm_instance") + async def test_convenience_functions_return_llm_instance( + self, mock_get_instance, mock_session + ): + """Convenience functions must return the LLM instance unchanged.""" + mock_llm = MagicMock() + mock_llm.model_name = "test-model" + mock_get_instance.return_value = mock_llm + + fast = await get_fast_llm(mock_session, search_space_id=1) 
+ assert fast == mock_llm + + mock_get_instance.reset_mock() + mock_get_instance.return_value = mock_llm + + long_context = await get_long_context_llm(mock_session, search_space_id=1) + assert long_context == mock_llm + + mock_get_instance.reset_mock() + mock_get_instance.return_value = mock_llm + + strategic = await get_strategic_llm(mock_session, search_space_id=1) + assert strategic == mock_llm diff --git a/surfsense_backend/tests/test_llm_service_extended.py b/surfsense_backend/tests/test_llm_service_extended.py new file mode 100644 index 000000000..0f9ca85e1 --- /dev/null +++ b/surfsense_backend/tests/test_llm_service_extended.py @@ -0,0 +1,433 @@ +""" +Extended tests for LLM service. +Tests LLM configuration validation and instance creation. +""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.services.llm_service import ( + LLMRole, + get_global_llm_config, + validate_llm_config, + get_search_space_llm_instance, + get_long_context_llm, + get_fast_llm, + get_strategic_llm, +) + + +class TestLLMRoleExtended: + """Extended tests for LLMRole constants.""" + + def test_role_long_context(self): + """Test long context role value.""" + assert LLMRole.LONG_CONTEXT == "long_context" + + def test_role_fast(self): + """Test fast role value.""" + assert LLMRole.FAST == "fast" + + def test_role_strategic(self): + """Test strategic role value.""" + assert LLMRole.STRATEGIC == "strategic" + + +class TestGetGlobalLLMConfig: + """Tests for get_global_llm_config function.""" + + def test_returns_none_for_positive_id(self): + """Test that positive IDs return None.""" + result = get_global_llm_config(1) + assert result is None + + def test_returns_none_for_zero_id(self): + """Test that zero ID returns None.""" + result = get_global_llm_config(0) + assert result is None + + def test_returns_config_for_matching_negative_id(self): + """Test that matching negative ID returns config.""" + with patch("app.services.llm_service.config") as mock_config: + mock_config.GLOBAL_LLM_CONFIGS = [ + {"id": -1, "name": "GPT-4", "provider": "OPENAI"}, + {"id": -2, "name": "Claude", "provider": "ANTHROPIC"}, + ] + + result = get_global_llm_config(-1) + + assert result is not None + assert result["name"] == "GPT-4" + + def test_returns_none_for_non_matching_negative_id(self): + """Test that non-matching negative ID returns None.""" + with patch("app.services.llm_service.config") as mock_config: + mock_config.GLOBAL_LLM_CONFIGS = [ + {"id": -1, "name": "GPT-4"}, + ] + + result = get_global_llm_config(-999) + + assert result is None + + +class TestValidateLLMConfig: + """Tests for validate_llm_config function.""" + + @pytest.mark.asyncio + async def test_validate_llm_config_success(self): + """Test successful LLM config validation.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Hello!" 
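+ # A non-empty response from the mocked ainvoke() is what validation should accept; the empty-response case in the next test must be rejected.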
+ mock_llm.ainvoke = AsyncMock(return_value=mock_response) + MockChatLiteLLM.return_value = mock_llm + + is_valid, error = await validate_llm_config( + provider="OPENAI", + model_name="gpt-4", + api_key="sk-test-key", + ) + + assert is_valid is True + assert error == "" + + @pytest.mark.asyncio + async def test_validate_llm_config_empty_response(self): + """Test validation fails with empty response.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + MockChatLiteLLM.return_value = mock_llm + + is_valid, error = await validate_llm_config( + provider="OPENAI", + model_name="gpt-4", + api_key="sk-test-key", + ) + + assert is_valid is False + assert "empty response" in error.lower() + + @pytest.mark.asyncio + async def test_validate_llm_config_exception(self): + """Test validation handles exceptions.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock(side_effect=Exception("API Error")) + MockChatLiteLLM.return_value = mock_llm + + is_valid, error = await validate_llm_config( + provider="OPENAI", + model_name="gpt-4", + api_key="sk-invalid-key", + ) + + assert is_valid is False + assert "API Error" in error + + @pytest.mark.asyncio + async def test_validate_llm_config_with_custom_provider(self): + """Test validation with custom provider.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Hello!" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + MockChatLiteLLM.return_value = mock_llm + + is_valid, error = await validate_llm_config( + provider="OPENAI", + model_name="custom-model", + api_key="sk-test-key", + custom_provider="custom/provider", + ) + + assert is_valid is True + # Verify custom provider was used in model string + call_args = MockChatLiteLLM.call_args + assert "custom/provider" in call_args.kwargs.get("model", "") + + @pytest.mark.asyncio + async def test_validate_llm_config_with_api_base(self): + """Test validation with custom API base.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Hello!" 
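+ # The custom api_base is expected to reach ChatLiteLLM unchanged; the assertion below checks it via call_args.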
+ mock_llm.ainvoke = AsyncMock(return_value=mock_response) + MockChatLiteLLM.return_value = mock_llm + + is_valid, error = await validate_llm_config( + provider="OPENAI", + model_name="gpt-4", + api_key="sk-test-key", + api_base="https://custom.api.com", + ) + + assert is_valid is True + call_args = MockChatLiteLLM.call_args + assert call_args.kwargs.get("api_base") == "https://custom.api.com" + + +class TestGetSearchSpaceLLMInstance: + """Tests for get_search_space_llm_instance function.""" + + @pytest.mark.asyncio + async def test_returns_none_for_nonexistent_search_space(self): + """Test returns None when search space doesn't exist.""" + mock_session = AsyncMock() + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_search_space_llm_instance( + session=mock_session, + search_space_id=999, + role=LLMRole.FAST, + ) + + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_for_invalid_role(self): + """Test returns None for invalid role.""" + mock_session = AsyncMock() + + # Mock search space exists + mock_search_space = MagicMock() + mock_search_space.id = 1 + mock_search_space.fast_llm_id = 1 + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_search_space + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_search_space_llm_instance( + session=mock_session, + search_space_id=1, + role="invalid_role", + ) + + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_when_no_llm_configured(self): + """Test returns None when LLM is not configured for role.""" + mock_session = AsyncMock() + + # Mock search space with no LLM configured + mock_search_space = MagicMock() + mock_search_space.id = 1 + mock_search_space.fast_llm_id = None + mock_search_space.long_context_llm_id = None + mock_search_space.strategic_llm_id = None + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_search_space + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await get_search_space_llm_instance( + session=mock_session, + search_space_id=1, + role=LLMRole.FAST, + ) + + assert result is None + + @pytest.mark.asyncio + async def test_returns_instance_for_global_config(self): + """Test returns LLM instance for global config.""" + mock_session = AsyncMock() + + # Mock search space with global config + mock_search_space = MagicMock() + mock_search_space.id = 1 + mock_search_space.fast_llm_id = -1 # Global config + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_search_space + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch("app.services.llm_service.config") as mock_config: + mock_config.GLOBAL_LLM_CONFIGS = [ + { + "id": -1, + "name": "GPT-4", + "provider": "OPENAI", + "model_name": "gpt-4", + "api_key": "sk-test", + "api_base": None, + "custom_provider": None, + "litellm_params": None, + } + ] + + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + MockChatLiteLLM.return_value = mock_llm + + result = await get_search_space_llm_instance( + session=mock_session, + search_space_id=1, + 
role=LLMRole.FAST, + ) + + assert result is not None + assert MockChatLiteLLM.called + + +class TestConvenienceFunctions: + """Tests for convenience wrapper functions.""" + + @pytest.mark.asyncio + async def test_get_long_context_llm(self): + """Test get_long_context_llm uses correct role.""" + mock_session = AsyncMock() + + with patch("app.services.llm_service.get_search_space_llm_instance") as mock_get: + mock_get.return_value = MagicMock() + + await get_long_context_llm(mock_session, 1) + + mock_get.assert_called_once_with(mock_session, 1, LLMRole.LONG_CONTEXT) + + @pytest.mark.asyncio + async def test_get_fast_llm(self): + """Test get_fast_llm uses correct role.""" + mock_session = AsyncMock() + + with patch("app.services.llm_service.get_search_space_llm_instance") as mock_get: + mock_get.return_value = MagicMock() + + await get_fast_llm(mock_session, 1) + + mock_get.assert_called_once_with(mock_session, 1, LLMRole.FAST) + + @pytest.mark.asyncio + async def test_get_strategic_llm(self): + """Test get_strategic_llm uses correct role.""" + mock_session = AsyncMock() + + with patch("app.services.llm_service.get_search_space_llm_instance") as mock_get: + mock_get.return_value = MagicMock() + + await get_strategic_llm(mock_session, 1) + + mock_get.assert_called_once_with(mock_session, 1, LLMRole.STRATEGIC) + + +class TestProviderMapping: + """Tests for provider string mapping.""" + + @pytest.mark.asyncio + async def test_openai_provider_mapping(self): + """Test OPENAI maps to openai.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Hello!" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + MockChatLiteLLM.return_value = mock_llm + + await validate_llm_config( + provider="OPENAI", + model_name="gpt-4", + api_key="sk-test", + ) + + call_args = MockChatLiteLLM.call_args + assert "openai/gpt-4" in call_args.kwargs.get("model", "") + + @pytest.mark.asyncio + async def test_anthropic_provider_mapping(self): + """Test ANTHROPIC maps to anthropic.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Hello!" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + MockChatLiteLLM.return_value = mock_llm + + await validate_llm_config( + provider="ANTHROPIC", + model_name="claude-3", + api_key="sk-test", + ) + + call_args = MockChatLiteLLM.call_args + assert "anthropic/claude-3" in call_args.kwargs.get("model", "") + + @pytest.mark.asyncio + async def test_google_provider_mapping(self): + """Test GOOGLE maps to gemini.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Hello!" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + MockChatLiteLLM.return_value = mock_llm + + await validate_llm_config( + provider="GOOGLE", + model_name="gemini-pro", + api_key="test-key", + ) + + call_args = MockChatLiteLLM.call_args + assert "gemini/gemini-pro" in call_args.kwargs.get("model", "") + + @pytest.mark.asyncio + async def test_ollama_provider_mapping(self): + """Test OLLAMA maps to ollama.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Hello!" 
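+ # OLLAMA is exercised with an empty api_key and only a local api_base, since a locally hosted Ollama server usually needs no key.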
+ mock_llm.ainvoke = AsyncMock(return_value=mock_response) + MockChatLiteLLM.return_value = mock_llm + + await validate_llm_config( + provider="OLLAMA", + model_name="llama2", + api_key="", + api_base="http://localhost:11434", + ) + + call_args = MockChatLiteLLM.call_args + assert "ollama/llama2" in call_args.kwargs.get("model", "") + + +class TestLiteLLMParams: + """Tests for litellm_params handling.""" + + @pytest.mark.asyncio + async def test_litellm_params_passed_to_instance(self): + """Test that litellm_params are passed to ChatLiteLLM.""" + with patch("app.services.llm_service.ChatLiteLLM") as MockChatLiteLLM: + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Hello!" + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + MockChatLiteLLM.return_value = mock_llm + + await validate_llm_config( + provider="OPENAI", + model_name="gpt-4", + api_key="sk-test", + litellm_params={"temperature": 0.7, "max_tokens": 1000}, + ) + + call_args = MockChatLiteLLM.call_args + assert call_args.kwargs.get("temperature") == 0.7 + assert call_args.kwargs.get("max_tokens") == 1000 diff --git a/surfsense_backend/tests/test_page_limit_service.py b/surfsense_backend/tests/test_page_limit_service.py new file mode 100644 index 000000000..1cef0b5dd --- /dev/null +++ b/surfsense_backend/tests/test_page_limit_service.py @@ -0,0 +1,354 @@ +""" +Tests for PageLimitService. + +This module tests the page limit service used for tracking user document processing limits. +""" + +import os +import tempfile +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.services.page_limit_service import PageLimitExceededError, PageLimitService + + +class TestPageLimitExceededError: + """Tests for PageLimitExceededError exception.""" + + def test_default_message(self): + """Test default error message.""" + error = PageLimitExceededError() + assert "Page limit exceeded" in str(error) + + def test_custom_message(self): + """Test custom error message.""" + error = PageLimitExceededError(message="Custom message") + assert str(error) == "Custom message" + + def test_stores_usage_info(self): + """Test error stores usage information.""" + error = PageLimitExceededError( + pages_used=100, + pages_limit=200, + pages_to_add=50, + ) + assert error.pages_used == 100 + assert error.pages_limit == 200 + assert error.pages_to_add == 50 + + def test_default_values(self): + """Test default values are zero.""" + error = PageLimitExceededError() + assert error.pages_used == 0 + assert error.pages_limit == 0 + assert error.pages_to_add == 0 + + +class TestPageLimitServiceEstimation: + """Tests for page estimation methods.""" + + @pytest.fixture + def service(self): + """Create a PageLimitService with mock session.""" + mock_session = AsyncMock() + return PageLimitService(mock_session) + + def test_estimate_pages_from_elements_with_page_numbers(self, service): + """Test estimation from elements with page number metadata.""" + elements = [] + for page in [1, 1, 2, 2, 3]: # 3 unique pages + elem = MagicMock() + elem.metadata = {"page_number": page} + elements.append(elem) + + result = service.estimate_pages_from_elements(elements) + assert result == 3 + + def test_estimate_pages_from_elements_by_content_length(self, service): + """Test estimation from elements by content length.""" + elements = [] + # Create elements with ~4000 chars total (should be 2 pages) + for i in range(4): + elem = MagicMock() + elem.metadata = {} # No page number + elem.page_content = "x" * 1000 # 1000 chars each + 
elements.append(elem) + + result = service.estimate_pages_from_elements(elements) + assert result == 2 # 4000 / 2000 = 2 + + def test_estimate_pages_from_elements_empty_list(self, service): + """Test estimation from empty elements list returns minimum 1.""" + result = service.estimate_pages_from_elements([]) + # Implementation uses max(1, ...) so minimum is 1 + assert result == 1 + + def test_estimate_pages_from_markdown_with_metadata(self, service): + """Test estimation from markdown documents with page metadata.""" + docs = [] + for page in range(5): + doc = MagicMock() + doc.metadata = {"page": page} + doc.text = "Content" + docs.append(doc) + + result = service.estimate_pages_from_markdown(docs) + assert result == 5 + + def test_estimate_pages_from_markdown_by_content(self, service): + """Test estimation from markdown by content length.""" + docs = [] + for i in range(2): + doc = MagicMock() + doc.metadata = {} + doc.text = "x" * 4000 # 4000 chars = 2 pages each + docs.append(doc) + + result = service.estimate_pages_from_markdown(docs) + assert result == 4 # (4000/2000) * 2 = 4 + + def test_estimate_pages_from_markdown_empty_list(self, service): + """Test estimation from empty markdown list.""" + result = service.estimate_pages_from_markdown([]) + assert result == 1 # Minimum 1 page + + def test_estimate_pages_from_content_length(self, service): + """Test estimation from content length.""" + # 5000 chars should be ~2 pages + result = service.estimate_pages_from_content_length(5000) + assert result == 2 + + def test_estimate_pages_from_content_length_minimum(self, service): + """Test minimum of 1 page for small content.""" + result = service.estimate_pages_from_content_length(100) + assert result == 1 + + def test_estimate_pages_from_content_length_zero(self, service): + """Test zero content length returns 1 page.""" + result = service.estimate_pages_from_content_length(0) + assert result == 1 + + +class TestPageEstimationFromFile: + """Tests for estimate_pages_before_processing method.""" + + @pytest.fixture + def service(self): + """Create a PageLimitService with mock session.""" + mock_session = AsyncMock() + return PageLimitService(mock_session) + + def test_file_not_found(self, service): + """Test error when file doesn't exist.""" + with pytest.raises(ValueError, match="File not found"): + service.estimate_pages_before_processing("/nonexistent/file.pdf") + + def test_text_file_estimation(self, service): + """Test estimation for text files.""" + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: + # Write ~6000 bytes (2 pages at 3000 bytes/page) + f.write(b"x" * 6000) + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + assert result == 2 + finally: + os.unlink(f.name) + + def test_small_text_file(self, service): + """Test minimum 1 page for small files.""" + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: + f.write(b"small") + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + assert result == 1 + finally: + os.unlink(f.name) + + def test_markdown_file_estimation(self, service): + """Test estimation for markdown files.""" + with tempfile.NamedTemporaryFile(suffix=".md", delete=False) as f: + # Need at least 6000 bytes for 2 pages (3000 bytes per page) + f.write(b"# Title\n" + b"x" * 6000) + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + assert result == 2 + finally: + os.unlink(f.name) + + def test_image_file_estimation(self, service): + """Test image 
files return 1 page.""" + for ext in [".jpg", ".png", ".gif", ".bmp"]: + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as f: + f.write(b"fake image data" * 1000) + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + assert result == 1, f"Expected 1 page for {ext}" + finally: + os.unlink(f.name) + + def test_word_doc_estimation(self, service): + """Test estimation for Word documents.""" + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as f: + # Write ~100KB (2 pages at 50KB/page) + f.write(b"x" * (100 * 1024)) + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + assert result == 2 + finally: + os.unlink(f.name) + + def test_presentation_estimation(self, service): + """Test estimation for presentation files.""" + with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as f: + # Write ~400KB (2 slides at 200KB/slide) + f.write(b"x" * (400 * 1024)) + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + assert result == 2 + finally: + os.unlink(f.name) + + def test_spreadsheet_estimation(self, service): + """Test estimation for spreadsheet files.""" + with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as f: + # Write ~200KB (2 sheets at 100KB/sheet) + f.write(b"x" * (200 * 1024)) + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + assert result == 2 + finally: + os.unlink(f.name) + + def test_html_file_estimation(self, service): + """Test estimation for HTML files.""" + with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as f: + f.write(b"" + b"x" * 5980 + b"") + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + assert result == 2 # ~6000 / 3000 = 2 + finally: + os.unlink(f.name) + + def test_unknown_extension(self, service): + """Test estimation for unknown file types.""" + with tempfile.NamedTemporaryFile(suffix=".xyz", delete=False) as f: + # Write ~160KB (2 pages at 80KB/page) + f.write(b"x" * (160 * 1024)) + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + assert result == 2 + finally: + os.unlink(f.name) + + def test_pdf_estimation_fallback(self, service): + """Test PDF estimation falls back when pypdf fails.""" + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: + # Write invalid PDF data (will fail to parse) + f.write(b"not a real pdf" * 10000) # ~140KB + f.flush() + + try: + result = service.estimate_pages_before_processing(f.name) + # Falls back to size estimation: ~140KB / 100KB = 1 page + assert result >= 1 + finally: + os.unlink(f.name) + + +class TestPageLimitServiceDatabase: + """Tests for database operations (mocked).""" + + @pytest.fixture + def mock_user(self): + """Create a mock user.""" + user = MagicMock() + user.pages_used = 50 + user.pages_limit = 100 + return user + + @pytest.fixture + def service(self): + """Create a PageLimitService with mock session.""" + mock_session = AsyncMock() + return PageLimitService(mock_session) + + @pytest.mark.asyncio + async def test_check_page_limit_success(self, service, mock_user): + """Test check_page_limit succeeds when within limit.""" + # Setup mock to return user data + mock_result = MagicMock() + mock_result.first.return_value = (50, 100) # pages_used, pages_limit + service.session.execute.return_value = mock_result + + has_capacity, pages_used, pages_limit = await service.check_page_limit( + "user-123", + estimated_pages=10, + ) + + assert has_capacity is True + assert 
pages_used == 50 + assert pages_limit == 100 + + @pytest.mark.asyncio + async def test_check_page_limit_exceeds(self, service): + """Test check_page_limit raises error when would exceed limit.""" + mock_result = MagicMock() + mock_result.first.return_value = (95, 100) # Near limit + service.session.execute.return_value = mock_result + + with pytest.raises(PageLimitExceededError) as exc_info: + await service.check_page_limit("user-123", estimated_pages=10) + + assert exc_info.value.pages_used == 95 + assert exc_info.value.pages_limit == 100 + assert exc_info.value.pages_to_add == 10 + + @pytest.mark.asyncio + async def test_check_page_limit_user_not_found(self, service): + """Test check_page_limit raises error for missing user.""" + mock_result = MagicMock() + mock_result.first.return_value = None + service.session.execute.return_value = mock_result + + with pytest.raises(ValueError, match="User with ID .* not found"): + await service.check_page_limit("nonexistent", estimated_pages=1) + + @pytest.mark.asyncio + async def test_get_page_usage(self, service): + """Test get_page_usage returns correct values.""" + mock_result = MagicMock() + mock_result.first.return_value = (75, 500) + service.session.execute.return_value = mock_result + + result = await service.get_page_usage("user-123") + + assert result == (75, 500) + + @pytest.mark.asyncio + async def test_get_page_usage_user_not_found(self, service): + """Test get_page_usage raises error for missing user.""" + mock_result = MagicMock() + mock_result.first.return_value = None + service.session.execute.return_value = mock_result + + with pytest.raises(ValueError, match="User with ID .* not found"): + await service.get_page_usage("nonexistent") diff --git a/surfsense_backend/tests/test_retrievers.py b/surfsense_backend/tests/test_retrievers.py new file mode 100644 index 000000000..8bf7a6331 --- /dev/null +++ b/surfsense_backend/tests/test_retrievers.py @@ -0,0 +1,98 @@ +""" +Tests for hybrid search retrievers. +Tests the ChucksHybridSearchRetriever and DocumentHybridSearchRetriever classes. 
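+The chunk retriever class is imported as ChucksHybridSearchRetriever, matching its name in app.retriever.chunks_hybrid_search.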
+""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever +from app.retriever.documents_hybrid_search import DocumentHybridSearchRetriever + + +class TestChunksHybridSearchRetriever: + """Tests for ChucksHybridSearchRetriever.""" + + def test_init(self): + """Test retriever initialization.""" + mock_session = AsyncMock() + retriever = ChucksHybridSearchRetriever(mock_session) + + assert retriever.db_session == mock_session + + @pytest.mark.asyncio + async def test_hybrid_search_returns_empty_on_no_results(self): + """Test hybrid search returns empty list when no results.""" + mock_session = AsyncMock() + + # Mock the session to return empty results + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + retriever = ChucksHybridSearchRetriever(mock_session) + + with patch.object(retriever, 'hybrid_search', new_callable=AsyncMock) as mock_search: + mock_search.return_value = [] + + result = await retriever.hybrid_search( + query_text="test query", + top_k=10, + search_space_id=1, + document_type="FILE", + ) + + assert result == [] + + +class TestDocumentHybridSearchRetriever: + """Tests for DocumentHybridSearchRetriever.""" + + def test_init(self): + """Test retriever initialization.""" + mock_session = AsyncMock() + retriever = DocumentHybridSearchRetriever(mock_session) + + assert retriever.db_session == mock_session + + @pytest.mark.asyncio + async def test_hybrid_search_returns_empty_on_no_results(self): + """Test hybrid search returns empty list when no results.""" + mock_session = AsyncMock() + + retriever = DocumentHybridSearchRetriever(mock_session) + + with patch.object(retriever, 'hybrid_search', new_callable=AsyncMock) as mock_search: + mock_search.return_value = [] + + result = await retriever.hybrid_search( + query_text="test query", + top_k=10, + search_space_id=1, + document_type="FILE", + ) + + assert result == [] + + +class TestRetrieverIntegration: + """Integration tests for retrievers.""" + + def test_chunk_retriever_uses_correct_session(self): + """Test chunk retriever uses provided session.""" + mock_session = AsyncMock() + mock_session.id = "test-session" + + retriever = ChucksHybridSearchRetriever(mock_session) + + assert retriever.db_session.id == "test-session" + + def test_document_retriever_uses_correct_session(self): + """Test document retriever uses provided session.""" + mock_session = AsyncMock() + mock_session.id = "test-session" + + retriever = DocumentHybridSearchRetriever(mock_session) + + assert retriever.db_session.id == "test-session" diff --git a/surfsense_backend/tests/test_routes_documents.py b/surfsense_backend/tests/test_routes_documents.py new file mode 100644 index 000000000..04437e4a7 --- /dev/null +++ b/surfsense_backend/tests/test_routes_documents.py @@ -0,0 +1,440 @@ +""" +Tests for documents routes. +Tests API endpoints with mocked database sessions and authentication. 
+""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException + +from app.routes.documents_routes import ( + create_documents, + read_documents, + search_documents, + read_document, + update_document, + delete_document, + get_document_type_counts, + get_document_by_chunk_id, +) +from app.schemas import DocumentsCreate, DocumentUpdate +from app.db import DocumentType + + +class TestCreateDocuments: + """Tests for the create_documents endpoint.""" + + @pytest.mark.asyncio + async def test_create_documents_invalid_type(self): + """Test creating documents with invalid type.""" + mock_session = AsyncMock() + mock_session.rollback = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Use a type that triggers the else branch + request = DocumentsCreate( + search_space_id=1, + document_type=DocumentType.FILE, # Not EXTENSION or YOUTUBE_VIDEO + content=[], + ) + + with patch("app.routes.documents_routes.check_permission") as mock_check: + mock_check.return_value = None + + with pytest.raises(HTTPException) as exc_info: + await create_documents( + request=request, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 400 + + +class TestReadDocuments: + """Tests for the read_documents endpoint.""" + + @pytest.mark.asyncio + async def test_read_documents_with_search_space_filter(self): + """Test reading documents filtered by search space.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock query results + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_result.scalars.return_value = mock_scalars + mock_result.scalar.return_value = 0 + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch("app.routes.documents_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await read_documents( + skip=0, + page=None, + page_size=50, + search_space_id=1, + document_types=None, + session=mock_session, + user=mock_user, + ) + + assert result.items == [] + assert result.total == 0 + + @pytest.mark.asyncio + async def test_read_documents_with_type_filter(self): + """Test reading documents filtered by type.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock query results + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_result.scalars.return_value = mock_scalars + mock_result.scalar.return_value = 0 + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch("app.routes.documents_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await read_documents( + skip=0, + page=None, + page_size=50, + search_space_id=1, + document_types="EXTENSION,FILE", + session=mock_session, + user=mock_user, + ) + + assert result.items == [] + + @pytest.mark.asyncio + async def test_read_documents_all_search_spaces(self): + """Test reading documents from all accessible search spaces.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock query results + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_result.scalars.return_value = mock_scalars + mock_result.scalar.return_value = 0 + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await read_documents( + skip=0, + page=None, + page_size=50, + search_space_id=None, + document_types=None, 
+ session=mock_session, + user=mock_user, + ) + + assert result.items == [] + + +class TestSearchDocuments: + """Tests for the search_documents endpoint.""" + + @pytest.mark.asyncio + async def test_search_documents_by_title(self): + """Test searching documents by title.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock query results + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_result.scalars.return_value = mock_scalars + mock_result.scalar.return_value = 0 + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch("app.routes.documents_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await search_documents( + title="test", + skip=0, + page=None, + page_size=50, + search_space_id=1, + document_types=None, + session=mock_session, + user=mock_user, + ) + + assert result.items == [] + assert result.total == 0 + + +class TestReadDocument: + """Tests for the read_document endpoint.""" + + @pytest.mark.asyncio + async def test_read_document_not_found(self): + """Test reading non-existent document.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + with pytest.raises(HTTPException) as exc_info: + await read_document( + document_id=999, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_read_document_success(self): + """Test successful document reading.""" + from datetime import datetime + + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock existing document + mock_document = MagicMock() + mock_document.id = 1 + mock_document.title = "Test Document" + mock_document.document_type = DocumentType.FILE + mock_document.document_metadata = {} + mock_document.content = "Test content" + mock_document.created_at = datetime.now() # Must be a datetime + mock_document.search_space_id = 1 + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_document + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch("app.routes.documents_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await read_document( + document_id=1, + session=mock_session, + user=mock_user, + ) + + assert result.id == 1 + assert result.title == "Test Document" + + +class TestUpdateDocument: + """Tests for the update_document endpoint.""" + + @pytest.mark.asyncio + async def test_update_document_not_found(self): + """Test updating non-existent document.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.rollback = AsyncMock() + + # DocumentUpdate requires document_type, content, and search_space_id + update_data = DocumentUpdate( + document_type=DocumentType.FILE, + content="Updated content", + search_space_id=1 + ) + + with pytest.raises(HTTPException) as exc_info: + await 
update_document( + document_id=999, + document_update=update_data, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_update_document_success(self): + """Test successful document update.""" + from datetime import datetime + + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock existing document + mock_document = MagicMock() + mock_document.id = 1 + mock_document.title = "Old Title" + mock_document.document_type = DocumentType.FILE + mock_document.document_metadata = {} + mock_document.content = "Test content" + mock_document.created_at = datetime.now() + mock_document.search_space_id = 1 + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_document + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.commit = AsyncMock() + mock_session.refresh = AsyncMock() + + # DocumentUpdate requires document_type, content, and search_space_id + update_data = DocumentUpdate( + document_type=DocumentType.FILE, + content="New content", + search_space_id=1 + ) + + with patch("app.routes.documents_routes.check_permission") as mock_check: + mock_check.return_value = None + + await update_document( + document_id=1, + document_update=update_data, + session=mock_session, + user=mock_user, + ) + + assert mock_session.commit.called + + +class TestDeleteDocument: + """Tests for the delete_document endpoint.""" + + @pytest.mark.asyncio + async def test_delete_document_not_found(self): + """Test deleting non-existent document.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.rollback = AsyncMock() + + with pytest.raises(HTTPException) as exc_info: + await delete_document( + document_id=999, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_delete_document_success(self): + """Test successful document deletion.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock existing document + mock_document = MagicMock() + mock_document.id = 1 + mock_document.search_space_id = 1 + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_document + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.delete = AsyncMock() + mock_session.commit = AsyncMock() + + with patch("app.routes.documents_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await delete_document( + document_id=1, + session=mock_session, + user=mock_user, + ) + + assert result["message"] == "Document deleted successfully" + assert mock_session.delete.called + + +class TestGetDocumentTypeCounts: + """Tests for the get_document_type_counts endpoint.""" + + @pytest.mark.asyncio + async def test_get_document_type_counts_success(self): + """Test getting document type counts.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock query result + mock_result = MagicMock() + mock_result.all.return_value = [("FILE", 5), ("EXTENSION", 3)] + 
mock_session.execute = AsyncMock(return_value=mock_result) + + with patch("app.routes.documents_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await get_document_type_counts( + search_space_id=1, + session=mock_session, + user=mock_user, + ) + + assert result == {"FILE": 5, "EXTENSION": 3} + + +class TestGetDocumentByChunkId: + """Tests for the get_document_by_chunk_id endpoint.""" + + @pytest.mark.asyncio + async def test_get_document_by_chunk_id_chunk_not_found(self): + """Test getting document when chunk doesn't exist.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty chunk result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + with pytest.raises(HTTPException) as exc_info: + await get_document_by_chunk_id( + chunk_id=999, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + assert "Chunk" in exc_info.value.detail diff --git a/surfsense_backend/tests/test_routes_llm_config.py b/surfsense_backend/tests/test_routes_llm_config.py new file mode 100644 index 000000000..166b0f59e --- /dev/null +++ b/surfsense_backend/tests/test_routes_llm_config.py @@ -0,0 +1,421 @@ +""" +Tests for LLM config routes. +Tests API endpoints with mocked database sessions and authentication. +""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException + +from app.routes.llm_config_routes import ( + get_global_llm_configs, + create_llm_config, + read_llm_configs, + read_llm_config, + update_llm_config, + delete_llm_config, + get_llm_preferences, + update_llm_preferences, + LLMPreferencesUpdate, +) +from app.schemas import LLMConfigCreate, LLMConfigUpdate +from app.db import LiteLLMProvider + + +class TestGetGlobalLLMConfigs: + """Tests for the get_global_llm_configs endpoint.""" + + @pytest.mark.asyncio + async def test_returns_global_configs_without_api_keys(self): + """Test that global configs are returned without exposing API keys.""" + mock_user = MagicMock() + mock_user.id = "user-123" + + with patch("app.routes.llm_config_routes.config") as mock_config: + mock_config.GLOBAL_LLM_CONFIGS = [ + { + "id": -1, + "name": "GPT-4", + "provider": "OPENAI", + "custom_provider": None, + "model_name": "gpt-4", + "api_key": "sk-secret-key", + "api_base": None, + "language": "en", + "litellm_params": {}, + }, + ] + + result = await get_global_llm_configs(user=mock_user) + + assert len(result) == 1 + # API key should not be in the response + assert "api_key" not in result[0] or result[0].get("api_key") != "sk-secret-key" + assert result[0]["name"] == "GPT-4" + assert result[0]["is_global"] is True + + @pytest.mark.asyncio + async def test_handles_empty_global_configs(self): + """Test handling when no global configs are configured.""" + mock_user = MagicMock() + mock_user.id = "user-123" + + with patch("app.routes.llm_config_routes.config") as mock_config: + mock_config.GLOBAL_LLM_CONFIGS = [] + + result = await get_global_llm_configs(user=mock_user) + + assert result == [] + + +class TestCreateLLMConfig: + """Tests for the create_llm_config endpoint.""" + + @pytest.mark.asyncio + async def test_create_llm_config_invalid_validation(self): + """Test creating LLM config with invalid validation.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + 
llm_config_data = LLMConfigCreate( + name="Test LLM", + provider=LiteLLMProvider.OPENAI, + model_name="gpt-4", + api_key="invalid-key", + search_space_id=1, + ) + + with patch("app.routes.llm_config_routes.check_permission") as mock_check: + mock_check.return_value = None + + with patch("app.routes.llm_config_routes.validate_llm_config") as mock_validate: + mock_validate.return_value = (False, "Invalid API key") + + with pytest.raises(HTTPException) as exc_info: + await create_llm_config( + llm_config=llm_config_data, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 400 + assert "Invalid LLM configuration" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_create_llm_config_success(self): + """Test successful LLM config creation.""" + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.commit = AsyncMock() + mock_session.refresh = AsyncMock() + + mock_user = MagicMock() + mock_user.id = "user-123" + + llm_config_data = LLMConfigCreate( + name="Test LLM", + provider=LiteLLMProvider.OPENAI, + model_name="gpt-4", + api_key="sk-valid-key", + search_space_id=1, + ) + + with patch("app.routes.llm_config_routes.check_permission") as mock_check: + mock_check.return_value = None + + with patch("app.routes.llm_config_routes.validate_llm_config") as mock_validate: + mock_validate.return_value = (True, "") + + with patch("app.routes.llm_config_routes.LLMConfig") as MockLLMConfig: + mock_config = MagicMock() + mock_config.id = 1 + mock_config.name = "Test LLM" + MockLLMConfig.return_value = mock_config + + await create_llm_config( + llm_config=llm_config_data, + session=mock_session, + user=mock_user, + ) + + assert mock_session.add.called + assert mock_session.commit.called + + +class TestReadLLMConfigs: + """Tests for the read_llm_configs endpoint.""" + + @pytest.mark.asyncio + async def test_read_llm_configs_success(self): + """Test successful reading of LLM configs.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock query result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch("app.routes.llm_config_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await read_llm_configs( + search_space_id=1, + skip=0, + limit=200, + session=mock_session, + user=mock_user, + ) + + assert isinstance(result, list) + + +class TestReadLLMConfig: + """Tests for the read_llm_config endpoint.""" + + @pytest.mark.asyncio + async def test_read_llm_config_not_found(self): + """Test reading non-existent LLM config.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + with pytest.raises(HTTPException) as exc_info: + await read_llm_config( + llm_config_id=999, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_read_llm_config_success(self): + """Test successful reading of LLM config.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock existing config + mock_config = MagicMock() + mock_config.id = 1 + 
mock_config.search_space_id = 1 + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_config + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch("app.routes.llm_config_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await read_llm_config( + llm_config_id=1, + session=mock_session, + user=mock_user, + ) + + assert result.id == 1 + + +class TestUpdateLLMConfig: + """Tests for the update_llm_config endpoint.""" + + @pytest.mark.asyncio + async def test_update_llm_config_not_found(self): + """Test updating non-existent LLM config.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.rollback = AsyncMock() + + update_data = LLMConfigUpdate(name="Updated Name") + + with pytest.raises(HTTPException) as exc_info: + await update_llm_config( + llm_config_id=999, + llm_config_update=update_data, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + +class TestDeleteLLMConfig: + """Tests for the delete_llm_config endpoint.""" + + @pytest.mark.asyncio + async def test_delete_llm_config_not_found(self): + """Test deleting non-existent LLM config.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.rollback = AsyncMock() + + with pytest.raises(HTTPException) as exc_info: + await delete_llm_config( + llm_config_id=999, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_delete_llm_config_success(self): + """Test successful LLM config deletion.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock existing config + mock_config = MagicMock() + mock_config.id = 1 + mock_config.search_space_id = 1 + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_config + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.delete = AsyncMock() + mock_session.commit = AsyncMock() + + with patch("app.routes.llm_config_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await delete_llm_config( + llm_config_id=1, + session=mock_session, + user=mock_user, + ) + + assert result["message"] == "LLM configuration deleted successfully" + assert mock_session.delete.called + assert mock_session.commit.called + + +class TestGetLLMPreferences: + """Tests for the get_llm_preferences endpoint.""" + + @pytest.mark.asyncio + async def test_get_llm_preferences_not_found(self): + """Test getting preferences for non-existent search space.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = 
AsyncMock(return_value=mock_result) + + with patch("app.routes.llm_config_routes.check_permission") as mock_check: + mock_check.return_value = None + + with pytest.raises(HTTPException) as exc_info: + await get_llm_preferences( + search_space_id=999, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + +class TestUpdateLLMPreferences: + """Tests for the update_llm_preferences endpoint.""" + + @pytest.mark.asyncio + async def test_update_llm_preferences_search_space_not_found(self): + """Test updating preferences for non-existent search space.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.rollback = AsyncMock() + + preferences = LLMPreferencesUpdate(fast_llm_id=1) + + with patch("app.routes.llm_config_routes.check_permission") as mock_check: + mock_check.return_value = None + + with pytest.raises(HTTPException) as exc_info: + await update_llm_preferences( + search_space_id=999, + preferences=preferences, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_update_llm_preferences_global_config_not_found(self): + """Test updating with non-existent global config.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock search space exists + mock_search_space = MagicMock() + mock_search_space.id = 1 + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_search_space + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.rollback = AsyncMock() + + preferences = LLMPreferencesUpdate(fast_llm_id=-999) # Non-existent global config + + with patch("app.routes.llm_config_routes.check_permission") as mock_check: + mock_check.return_value = None + + with patch("app.routes.llm_config_routes.config") as mock_config: + mock_config.GLOBAL_LLM_CONFIGS = [] + + with pytest.raises(HTTPException) as exc_info: + await update_llm_preferences( + search_space_id=1, + preferences=preferences, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 diff --git a/surfsense_backend/tests/test_routes_search_spaces.py b/surfsense_backend/tests/test_routes_search_spaces.py new file mode 100644 index 000000000..0ae59e207 --- /dev/null +++ b/surfsense_backend/tests/test_routes_search_spaces.py @@ -0,0 +1,329 @@ +""" +Tests for search spaces routes. +Tests API endpoints with mocked database sessions and authentication. 
+""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException + +from app.routes.search_spaces_routes import ( + create_search_space, + read_search_spaces, + read_search_space, + update_search_space, + delete_search_space, + create_default_roles_and_membership, +) +from app.schemas import SearchSpaceCreate, SearchSpaceUpdate + + +class TestCreateDefaultRolesAndMembership: + """Tests for the create_default_roles_and_membership helper function.""" + + @pytest.mark.asyncio + async def test_creates_default_roles(self): + """Test that default roles are created for a search space.""" + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.flush = AsyncMock() + + with patch("app.routes.search_spaces_routes.get_default_roles_config") as mock_get_roles: + mock_get_roles.return_value = [ + { + "name": "Owner", + "description": "Full access", + "permissions": ["*"], + "is_default": False, + "is_system_role": True, + }, + { + "name": "Editor", + "description": "Can edit", + "permissions": ["documents:create"], + "is_default": True, + "is_system_role": True, + }, + ] + + await create_default_roles_and_membership( + mock_session, + search_space_id=1, + owner_user_id="user-123", + ) + + # Should add roles and membership + assert mock_session.add.call_count >= 2 + assert mock_session.flush.call_count >= 1 + + +class TestCreateSearchSpace: + """Tests for the create_search_space endpoint.""" + + @pytest.mark.asyncio + async def test_create_search_space_success(self): + """Test successful search space creation.""" + mock_session = AsyncMock() + mock_session.add = MagicMock() + mock_session.flush = AsyncMock() + mock_session.commit = AsyncMock() + mock_session.refresh = AsyncMock() + + mock_user = MagicMock() + mock_user.id = "user-123" + + search_space_data = SearchSpaceCreate(name="Test Space") + + with patch("app.routes.search_spaces_routes.create_default_roles_and_membership") as mock_create_roles: + mock_create_roles.return_value = None + + # Mock the SearchSpace class + with patch("app.routes.search_spaces_routes.SearchSpace") as MockSearchSpace: + mock_search_space = MagicMock() + mock_search_space.id = 1 + mock_search_space.name = "Test Space" + MockSearchSpace.return_value = mock_search_space + + await create_search_space( + search_space=search_space_data, + session=mock_session, + user=mock_user, + ) + + assert mock_session.add.called + assert mock_session.commit.called + + @pytest.mark.asyncio + async def test_create_search_space_database_error(self): + """Test search space creation handles database errors.""" + mock_session = AsyncMock() + mock_session.add = MagicMock(side_effect=Exception("Database error")) + mock_session.rollback = AsyncMock() + + mock_user = MagicMock() + mock_user.id = "user-123" + + search_space_data = SearchSpaceCreate(name="Test Space") + + with pytest.raises(HTTPException) as exc_info: + await create_search_space( + search_space=search_space_data, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 500 + + +class TestReadSearchSpaces: + """Tests for the read_search_spaces endpoint.""" + + @pytest.mark.asyncio + async def test_read_search_spaces_owned_only(self): + """Test reading only owned search spaces.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock the query result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_result.scalars.return_value = mock_scalars + 
mock_session.execute = AsyncMock(return_value=mock_result) + + result = await read_search_spaces( + skip=0, + limit=200, + owned_only=True, + session=mock_session, + user=mock_user, + ) + + assert isinstance(result, list) + + @pytest.mark.asyncio + async def test_read_search_spaces_all_accessible(self): + """Test reading all accessible search spaces.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock the query result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.all.return_value = [] + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_result.scalar.return_value = 0 + mock_session.execute = AsyncMock(return_value=mock_result) + + result = await read_search_spaces( + skip=0, + limit=200, + owned_only=False, + session=mock_session, + user=mock_user, + ) + + assert isinstance(result, list) + + +class TestReadSearchSpace: + """Tests for the read_search_space endpoint.""" + + @pytest.mark.asyncio + async def test_read_search_space_not_found(self): + """Test reading non-existent search space.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + + with patch("app.routes.search_spaces_routes.check_search_space_access") as mock_check: + mock_check.return_value = None + + with pytest.raises(HTTPException) as exc_info: + await read_search_space( + search_space_id=999, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + +class TestUpdateSearchSpace: + """Tests for the update_search_space endpoint.""" + + @pytest.mark.asyncio + async def test_update_search_space_not_found(self): + """Test updating non-existent search space.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.rollback = AsyncMock() + + update_data = SearchSpaceUpdate(name="Updated Name") + + with patch("app.routes.search_spaces_routes.check_permission") as mock_check: + mock_check.return_value = None + + with pytest.raises(HTTPException) as exc_info: + await update_search_space( + search_space_id=999, + search_space_update=update_data, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_update_search_space_success(self): + """Test successful search space update.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock existing search space + mock_search_space = MagicMock() + mock_search_space.id = 1 + mock_search_space.name = "Old Name" + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_search_space + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.commit = AsyncMock() + mock_session.refresh = AsyncMock() + + update_data = SearchSpaceUpdate(name="New Name") + + with patch("app.routes.search_spaces_routes.check_permission") as mock_check: + mock_check.return_value = None + + await 
update_search_space( + search_space_id=1, + search_space_update=update_data, + session=mock_session, + user=mock_user, + ) + + assert mock_session.commit.called + + +class TestDeleteSearchSpace: + """Tests for the delete_search_space endpoint.""" + + @pytest.mark.asyncio + async def test_delete_search_space_not_found(self): + """Test deleting non-existent search space.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock empty result + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = None + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.rollback = AsyncMock() + + with patch("app.routes.search_spaces_routes.check_permission") as mock_check: + mock_check.return_value = None + + with pytest.raises(HTTPException) as exc_info: + await delete_search_space( + search_space_id=999, + session=mock_session, + user=mock_user, + ) + + assert exc_info.value.status_code == 404 + + @pytest.mark.asyncio + async def test_delete_search_space_success(self): + """Test successful search space deletion.""" + mock_session = AsyncMock() + mock_user = MagicMock() + mock_user.id = "user-123" + + # Mock existing search space + mock_search_space = MagicMock() + mock_search_space.id = 1 + + mock_result = MagicMock() + mock_scalars = MagicMock() + mock_scalars.first.return_value = mock_search_space + mock_result.scalars.return_value = mock_scalars + mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.delete = AsyncMock() + mock_session.commit = AsyncMock() + + with patch("app.routes.search_spaces_routes.check_permission") as mock_check: + mock_check.return_value = None + + result = await delete_search_space( + search_space_id=1, + session=mock_session, + user=mock_user, + ) + + assert result["message"] == "Search space deleted successfully" + assert mock_session.delete.called + assert mock_session.commit.called
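Reviewer note (not part of the diff above): these route tests repeat the same stanza for stubbing AsyncSession.execute() -- a MagicMock whose .scalars().first(), .scalars().all(), and .scalar() return canned values. If that duplication becomes a maintenance burden, it could be factored into a small helper. The sketch below is a suggestion only; the module path tests/helpers.py and the names make_execute_result / make_mock_session are hypothetical, not something this PR defines.

# tests/helpers.py (hypothetical sketch, not part of this PR)
from unittest.mock import AsyncMock, MagicMock


def make_execute_result(scalars_first=None, scalars_all=None, scalar_value=None):
    # Mimics the SQLAlchemy result double the tests above build by hand:
    # result.scalars().first(), result.scalars().all(), and result.scalar().
    result = MagicMock()
    scalars = MagicMock()
    scalars.first.return_value = scalars_first
    scalars.all.return_value = [] if scalars_all is None else scalars_all
    result.scalars.return_value = scalars
    result.scalar.return_value = scalar_value
    return result


def make_mock_session(result):
    # AsyncMock session whose execute() resolves to the prepared result.
    session = AsyncMock()
    session.execute = AsyncMock(return_value=result)
    return session

With such a helper, a not-found test would reduce to roughly:

    mock_session = make_mock_session(make_execute_result(scalars_first=None))
    with pytest.raises(HTTPException) as exc_info:
        await read_document(document_id=999, session=mock_session, user=mock_user)
    assert exc_info.value.status_code == 404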
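A related observation: every test constructs the same mock_user (a MagicMock with id "user-123"), and the @pytest.mark.asyncio markers assume pytest-asyncio or an equivalent plugin is configured for the suite. Shared fixtures in a conftest.py would remove the per-test boilerplate; again a sketch under those assumptions, with fixture names chosen for illustration only.

# tests/conftest.py (hypothetical sketch, not part of this PR)
from unittest.mock import AsyncMock, MagicMock

import pytest


@pytest.fixture
def mock_user():
    # Stand-in authenticated user with the fixed id the tests above use.
    user = MagicMock()
    user.id = "user-123"
    return user


@pytest.fixture
def mock_session():
    # AsyncMock database session; add() stays a plain MagicMock because
    # AsyncSession.add() is synchronous, matching the pattern in the tests above.
    session = AsyncMock()
    session.add = MagicMock()
    return session

Tests would then accept mock_user and mock_session as parameters instead of rebuilding them inline in every test.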