|
| 1 | +import asyncio |
1 | 2 | import inspect |
| 3 | +import sys |
2 | 4 | import unittest |
3 | | -from unittest.mock import Mock, patch, MagicMock, AsyncMock |
4 | 5 | from typing import Any, List, Dict |
5 | | -import sys |
| 6 | +from unittest.mock import AsyncMock, MagicMock, Mock, patch |
| 7 | + |
6 | 8 | import pytest |
7 | 9 |
|
8 | 10 | boto3_mock = MagicMock() |
@@ -903,5 +905,171 @@ def __init__(self): |
903 | 905 | assert mock_build_tool_info.call_count == 2 |
904 | 906 |
|
905 | 907 |
|
| 908 | +class TestInitializeToolsOnStartup: |
| 909 | + """Test cases for initialize_tools_on_startup function""" |
| 910 | + |
| 911 | + @patch('backend.services.tool_configuration_service.get_all_tenant_ids') |
| 912 | + @patch('backend.services.tool_configuration_service.update_tool_list') |
| 913 | + @patch('backend.services.tool_configuration_service.query_all_tools') |
| 914 | + @patch('backend.services.tool_configuration_service.logger') |
| 915 | + async def test_initialize_tools_on_startup_no_tenants(self, mock_logger, mock_query_tools, mock_update_tool_list, mock_get_tenants): |
| 916 | + """Test initialize_tools_on_startup when no tenants are found""" |
| 917 | + # Mock get_all_tenant_ids to return empty list |
| 918 | + mock_get_tenants.return_value = [] |
| 919 | + |
| 920 | + # Import and call the function |
| 921 | + from backend.services.tool_configuration_service import initialize_tools_on_startup |
| 922 | + await initialize_tools_on_startup() |
| 923 | + |
| 924 | + # Verify warning was logged |
| 925 | + mock_logger.warning.assert_called_with("No tenants found in database, skipping tool initialization") |
| 926 | + mock_update_tool_list.assert_not_called() |
| 927 | + |
| 928 | + @patch('backend.services.tool_configuration_service.get_all_tenant_ids') |
| 929 | + @patch('backend.services.tool_configuration_service.update_tool_list') |
| 930 | + @patch('backend.services.tool_configuration_service.query_all_tools') |
| 931 | + @patch('backend.services.tool_configuration_service.logger') |
| 932 | + async def test_initialize_tools_on_startup_success(self, mock_logger, mock_query_tools, mock_update_tool_list, mock_get_tenants): |
| 933 | + """Test successful tool initialization for all tenants""" |
| 934 | + # Mock tenant IDs |
| 935 | + tenant_ids = ["tenant_1", "tenant_2", "default_tenant"] |
| 936 | + mock_get_tenants.return_value = tenant_ids |
| 937 | + |
| 938 | + # Mock update_tool_list to succeed |
| 939 | + mock_update_tool_list.return_value = None |
| 940 | + |
| 941 | + # Mock query_all_tools to return mock tools |
| 942 | + mock_tools = [ |
| 943 | + {"tool_id": "tool_1", "name": "Test Tool 1"}, |
| 944 | + {"tool_id": "tool_2", "name": "Test Tool 2"} |
| 945 | + ] |
| 946 | + mock_query_tools.return_value = mock_tools |
| 947 | + |
| 948 | + # Import and call the function |
| 949 | + from backend.services.tool_configuration_service import initialize_tools_on_startup |
| 950 | + await initialize_tools_on_startup() |
| 951 | + |
| 952 | + # Verify update_tool_list was called for each tenant |
| 953 | + assert mock_update_tool_list.call_count == len(tenant_ids) |
| 954 | + |
| 955 | + # Verify success logging |
| 956 | + mock_logger.info.assert_any_call("Tool initialization completed!") |
| 957 | + mock_logger.info.assert_any_call("Total tools available across all tenants: 6") # 2 tools * 3 tenants |
| 958 | + mock_logger.info.assert_any_call("Successfully processed: 3/3 tenants") |
| 959 | + |
| 960 | + @patch('backend.services.tool_configuration_service.get_all_tenant_ids') |
| 961 | + @patch('backend.services.tool_configuration_service.update_tool_list') |
| 962 | + @patch('backend.services.tool_configuration_service.logger') |
| 963 | + async def test_initialize_tools_on_startup_timeout(self, mock_logger, mock_update_tool_list, mock_get_tenants): |
| 964 | + """Test tool initialization timeout scenario""" |
| 965 | + tenant_ids = ["tenant_1", "tenant_2"] |
| 966 | + mock_get_tenants.return_value = tenant_ids |
| 967 | + |
| 968 | + # Mock update_tool_list to timeout |
| 969 | + mock_update_tool_list.side_effect = asyncio.TimeoutError() |
| 970 | + |
| 971 | + # Import and call the function |
| 972 | + from backend.services.tool_configuration_service import initialize_tools_on_startup |
| 973 | + await initialize_tools_on_startup() |
| 974 | + |
| 975 | + # Verify timeout error was logged for each tenant |
| 976 | + assert mock_logger.error.call_count == len(tenant_ids) |
| 977 | + for call in mock_logger.error.call_args_list: |
| 978 | + assert "timed out" in str(call) |
| 979 | + |
| 980 | + # Verify failed tenants were logged |
| 981 | + mock_logger.warning.assert_called_once() |
| 982 | + warning_call = mock_logger.warning.call_args[0][0] |
| 983 | + assert "Failed tenants:" in warning_call |
| 984 | + assert "tenant_1 (timeout)" in warning_call |
| 985 | + assert "tenant_2 (timeout)" in warning_call |
| 986 | + |
| 987 | + @patch('backend.services.tool_configuration_service.get_all_tenant_ids') |
| 988 | + @patch('backend.services.tool_configuration_service.update_tool_list') |
| 989 | + @patch('backend.services.tool_configuration_service.logger') |
| 990 | + async def test_initialize_tools_on_startup_exception(self, mock_logger, mock_update_tool_list, mock_get_tenants): |
| 991 | + """Test tool initialization with exception during processing""" |
| 992 | + tenant_ids = ["tenant_1", "tenant_2"] |
| 993 | + mock_get_tenants.return_value = tenant_ids |
| 994 | + |
| 995 | + # Mock update_tool_list to raise exception |
| 996 | + mock_update_tool_list.side_effect = Exception("Database connection failed") |
| 997 | + |
| 998 | + # Import and call the function |
| 999 | + from backend.services.tool_configuration_service import initialize_tools_on_startup |
| 1000 | + await initialize_tools_on_startup() |
| 1001 | + |
| 1002 | + # Verify exception error was logged for each tenant |
| 1003 | + assert mock_logger.error.call_count == len(tenant_ids) |
| 1004 | + for call in mock_logger.error.call_args_list: |
| 1005 | + assert "Tool initialization failed" in str(call) |
| 1006 | + assert "Database connection failed" in str(call) |
| 1007 | + |
| 1008 | + # Verify failed tenants were logged |
| 1009 | + mock_logger.warning.assert_called_once() |
| 1010 | + warning_call = mock_logger.warning.call_args[0][0] |
| 1011 | + assert "Failed tenants:" in warning_call |
| 1012 | + assert "tenant_1 (error: Database connection failed)" in warning_call |
| 1013 | + assert "tenant_2 (error: Database connection failed)" in warning_call |
| 1014 | + |
| 1015 | + @patch('backend.services.tool_configuration_service.get_all_tenant_ids') |
| 1016 | + @patch('backend.services.tool_configuration_service.logger') |
| 1017 | + async def test_initialize_tools_on_startup_critical_exception(self, mock_logger, mock_get_tenants): |
| 1018 | + """Test tool initialization when get_all_tenant_ids raises exception""" |
| 1019 | + # Mock get_all_tenant_ids to raise exception |
| 1020 | + mock_get_tenants.side_effect = Exception("Database connection failed") |
| 1021 | + |
| 1022 | + # Import and call the function |
| 1023 | + from backend.services.tool_configuration_service import initialize_tools_on_startup |
| 1024 | + |
| 1025 | + # Should raise the exception |
| 1026 | + with pytest.raises(Exception, match="Database connection failed"): |
| 1027 | + await initialize_tools_on_startup() |
| 1028 | + |
| 1029 | + # Verify critical error was logged |
| 1030 | + mock_logger.error.assert_called_with("❌ Tool initialization failed: Database connection failed") |
| 1031 | + |
| 1032 | + @patch('backend.services.tool_configuration_service.get_all_tenant_ids') |
| 1033 | + @patch('backend.services.tool_configuration_service.update_tool_list') |
| 1034 | + @patch('backend.services.tool_configuration_service.query_all_tools') |
| 1035 | + @patch('backend.services.tool_configuration_service.logger') |
| 1036 | + async def test_initialize_tools_on_startup_mixed_results(self, mock_logger, mock_query_tools, mock_update_tool_list, mock_get_tenants): |
| 1037 | + """Test tool initialization with mixed success and failure results""" |
| 1038 | + tenant_ids = ["tenant_1", "tenant_2", "tenant_3"] |
| 1039 | + mock_get_tenants.return_value = tenant_ids |
| 1040 | + |
| 1041 | + # Mock update_tool_list with mixed results |
| 1042 | + def side_effect(*args, **kwargs): |
| 1043 | + tenant_id = kwargs.get('tenant_id') |
| 1044 | + if tenant_id == "tenant_1": |
| 1045 | + return None # Success |
| 1046 | + elif tenant_id == "tenant_2": |
| 1047 | + raise asyncio.TimeoutError() # Timeout |
| 1048 | + else: # tenant_3 |
| 1049 | + raise Exception("Connection error") # Exception |
| 1050 | + |
| 1051 | + mock_update_tool_list.side_effect = side_effect |
| 1052 | + |
| 1053 | + # Mock query_all_tools for successful tenant |
| 1054 | + mock_tools = [{"tool_id": "tool_1", "name": "Test Tool"}] |
| 1055 | + mock_query_tools.return_value = mock_tools |
| 1056 | + |
| 1057 | + # Import and call the function |
| 1058 | + from backend.services.tool_configuration_service import initialize_tools_on_startup |
| 1059 | + await initialize_tools_on_startup() |
| 1060 | + |
| 1061 | + # Verify mixed results logging |
| 1062 | + mock_logger.info.assert_any_call("Tool initialization completed!") |
| 1063 | + mock_logger.info.assert_any_call("Total tools available across all tenants: 1") |
| 1064 | + mock_logger.info.assert_any_call("Successfully processed: 1/3 tenants") |
| 1065 | + |
| 1066 | + # Verify failed tenants were logged |
| 1067 | + mock_logger.warning.assert_called_once() |
| 1068 | + warning_call = mock_logger.warning.call_args[0][0] |
| 1069 | + assert "Failed tenants:" in warning_call |
| 1070 | + assert "tenant_2 (timeout)" in warning_call |
| 1071 | + assert "tenant_3 (error: Connection error)" in warning_call |
| 1072 | + |
| 1073 | + |
906 | 1074 | if __name__ == '__main__': |
907 | 1075 | unittest.main() |
0 commit comments