Skip to content

Commit f2346e9

Browse files
authored
♻️ Refresh the tool list when backend service is initialized #1250
2 parents 9fcceea + b6b2664 commit f2346e9

File tree

5 files changed

+331
-7
lines changed

5 files changed

+331
-7
lines changed

backend/database/user_tenant_db.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
from typing import Any, Dict, Optional
55

6+
from consts.const import DEFAULT_TENANT_ID
67
from database.client import as_dict, get_db_session
78
from database.db_models import UserTenant
89

@@ -28,6 +29,27 @@ def get_user_tenant_by_user_id(user_id: str) -> Optional[Dict[str, Any]]:
2829
return None
2930

3031

32+
def get_all_tenant_ids() -> list[str]:
33+
"""
34+
Get all unique tenant IDs from the database
35+
36+
Returns:
37+
list[str]: List of unique tenant IDs
38+
"""
39+
with get_db_session() as session:
40+
result = session.query(UserTenant.tenant_id).filter(
41+
UserTenant.delete_flag == "N"
42+
).distinct().all()
43+
44+
tenant_ids = [row[0] for row in result]
45+
46+
# Add default tenant_id if not already in the list
47+
if DEFAULT_TENANT_ID not in tenant_ids:
48+
tenant_ids.append(DEFAULT_TENANT_ID)
49+
50+
return tenant_ids
51+
52+
3153
def insert_user_tenant(user_id: str, tenant_id: str):
3254
"""
3355
Insert user tenant relationship

backend/main_service.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,39 @@
11
import uvicorn
22
import logging
33
import warnings
4+
import asyncio
45
warnings.filterwarnings("ignore", category=UserWarning)
56

67
from dotenv import load_dotenv
78
load_dotenv()
89

910
from apps.base_app import app
1011
from utils.logging_utils import configure_logging, configure_elasticsearch_logging
12+
from services.tool_configuration_service import initialize_tools_on_startup
1113

1214

1315
configure_logging(logging.INFO)
1416
configure_elasticsearch_logging()
1517
logger = logging.getLogger("main_service")
1618

19+
20+
async def startup_initialization():
21+
"""
22+
Perform initialization tasks during server startup
23+
"""
24+
logger.info("Starting server initialization...")
25+
26+
try:
27+
# Initialize tools on startup - service layer handles detailed logging
28+
await initialize_tools_on_startup()
29+
logger.info("Server initialization completed successfully!")
30+
31+
except Exception as e:
32+
logger.error(f"Server initialization failed: {str(e)}")
33+
# Don't raise the exception to allow server to start even if initialization fails
34+
logger.warning("Server will continue to start despite initialization issues")
35+
36+
1737
if __name__ == "__main__":
18-
# Scan agents and update to database
38+
asyncio.run(startup_initialization())
1939
uvicorn.run(app, host="0.0.0.0", port=5010, log_level="info")

backend/services/tool_configuration_service.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import importlib
23
import inspect
34
import json
@@ -10,7 +11,7 @@
1011
import jsonref
1112
from mcpadapt.smolagents_adapter import _sanitize_function_name
1213

13-
from consts.const import LOCAL_MCP_SERVER
14+
from consts.const import DEFAULT_USER_ID, LOCAL_MCP_SERVER
1415
from consts.exceptions import MCPConnectionError
1516
from consts.model import ToolInstanceInfoRequest, ToolInfo, ToolSourceEnum
1617
from database.remote_mcp_db import get_mcp_records_by_tenant
@@ -20,6 +21,7 @@
2021
query_tool_instances_by_id,
2122
update_tool_table_from_scan_tool_list
2223
)
24+
from database.user_tenant_db import get_all_tenant_ids
2325

2426
logger = logging.getLogger("tool_configuration_service")
2527

@@ -346,4 +348,69 @@ async def list_all_tools(tenant_id: str):
346348
}
347349
formatted_tools.append(formatted_tool)
348350

349-
return formatted_tools
351+
return formatted_tools
352+
353+
354+
async def initialize_tools_on_startup():
355+
"""
356+
Initialize and scan all tools during server startup for all tenants
357+
358+
This function scans all available tools (local, LangChain, and MCP)
359+
and updates the database with the latest tool information for all tenants.
360+
"""
361+
362+
logger.info("Starting tool initialization on server startup...")
363+
364+
try:
365+
# Get all tenant IDs from the database
366+
tenant_ids = get_all_tenant_ids()
367+
368+
if not tenant_ids:
369+
logger.warning("No tenants found in database, skipping tool initialization")
370+
return
371+
372+
logger.info(f"Found {len(tenant_ids)} tenants: {tenant_ids}")
373+
374+
total_tools = 0
375+
successful_tenants = 0
376+
failed_tenants = []
377+
378+
# Process each tenant
379+
for tenant_id in tenant_ids:
380+
try:
381+
logger.info(f"Initializing tools for tenant: {tenant_id}")
382+
383+
# Add timeout to prevent hanging during startup
384+
try:
385+
await asyncio.wait_for(
386+
update_tool_list(tenant_id=tenant_id, user_id=DEFAULT_USER_ID),
387+
timeout=60.0 # 60 seconds timeout per tenant
388+
)
389+
390+
# Get the count of tools for this tenant
391+
tools_info = query_all_tools(tenant_id)
392+
tenant_tool_count = len(tools_info)
393+
total_tools += tenant_tool_count
394+
successful_tenants += 1
395+
396+
logger.info(f"Tenant {tenant_id}: {tenant_tool_count} tools initialized")
397+
398+
except asyncio.TimeoutError:
399+
logger.error(f"Tool initialization timed out for tenant {tenant_id}")
400+
failed_tenants.append(f"{tenant_id} (timeout)")
401+
402+
except Exception as e:
403+
logger.error(f"Tool initialization failed for tenant {tenant_id}: {str(e)}")
404+
failed_tenants.append(f"{tenant_id} (error: {str(e)})")
405+
406+
# Log final results
407+
logger.info(f"Tool initialization completed!")
408+
logger.info(f"Total tools available across all tenants: {total_tools}")
409+
logger.info(f"Successfully processed: {successful_tenants}/{len(tenant_ids)} tenants")
410+
411+
if failed_tenants:
412+
logger.warning(f"Failed tenants: {', '.join(failed_tenants)}")
413+
414+
except Exception as e:
415+
logger.error(f"❌ Tool initialization failed: {str(e)}")
416+
raise

test/backend/database/test_user_tenant_db.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@
6666
# Now import the functions to be tested
6767
from backend.database.user_tenant_db import (
6868
get_user_tenant_by_user_id,
69-
insert_user_tenant
69+
insert_user_tenant,
70+
get_all_tenant_ids
7071
)
7172

7273
class MockUserTenant:
@@ -301,4 +302,50 @@ def test_get_user_tenant_by_user_id_with_deleted_record(monkeypatch, mock_sessio
301302

302303
assert result is None
303304
# Verify that the filter was called with correct conditions
304-
query.filter.assert_called_once()
305+
query.filter.assert_called_once()
306+
307+
308+
def test_get_all_tenant_ids_empty_database(monkeypatch, mock_session):
309+
"""Test get_all_tenant_ids when database is empty - should return only DEFAULT_TENANT_ID"""
310+
session, query = mock_session
311+
312+
# Mock empty database result
313+
query.filter.return_value.distinct.return_value.all.return_value = []
314+
315+
mock_ctx = MagicMock()
316+
mock_ctx.__enter__.return_value = session
317+
mock_ctx.__exit__.return_value = None
318+
monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx)
319+
320+
result = get_all_tenant_ids()
321+
322+
assert result == ["default_tenant"] # DEFAULT_TENANT_ID from consts_mock
323+
assert len(result) == 1
324+
325+
326+
def test_get_all_tenant_ids_with_existing_tenants(monkeypatch, mock_session):
327+
"""Test get_all_tenant_ids with existing tenants - should include all plus DEFAULT_TENANT_ID"""
328+
session, query = mock_session
329+
330+
# Mock database result with existing tenants
331+
mock_tenants = [
332+
("tenant_1",),
333+
("tenant_2",),
334+
("tenant_3",)
335+
]
336+
query.filter.return_value.distinct.return_value.all.return_value = mock_tenants
337+
338+
mock_ctx = MagicMock()
339+
mock_ctx.__enter__.return_value = session
340+
mock_ctx.__exit__.return_value = None
341+
monkeypatch.setattr("backend.database.user_tenant_db.get_db_session", lambda: mock_ctx)
342+
343+
result = get_all_tenant_ids()
344+
345+
assert len(result) == 4 # 3 existing + 1 default
346+
assert "tenant_1" in result
347+
assert "tenant_2" in result
348+
assert "tenant_3" in result
349+
assert "default_tenant" in result # DEFAULT_TENANT_ID from consts_mock
350+
# Should not duplicate DEFAULT_TENANT_ID
351+
assert result.count("default_tenant") == 1

0 commit comments

Comments
 (0)