Skip to content

Commit f7c0eb7

Browse files
authored
✨ use specify the model when running agent;long-text model use main model default
2 parents 0e9d162 + cb11377 commit f7c0eb7

File tree

2 files changed

+216
-34
lines changed

2 files changed

+216
-34
lines changed

backend/agents/create_agent_info.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from services.memory_config_service import build_memory_context
1616
from database.agent_db import search_agent_info_by_agent_id, query_sub_agents_id_list
1717
from database.tool_db import search_tools_for_sub_agent
18+
from database.model_management_db import get_model_records
19+
from utils.model_name_utils import add_repo_to_name
1820
from utils.prompt_template_utils import get_agent_prompt_template
1921
from utils.config_utils import tenant_config_manager, get_model_name_from_config
2022
from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE
@@ -24,21 +26,34 @@
2426

2527

2628
async def create_model_config_list(tenant_id):
29+
records = get_model_records({"model_type": "llm"}, tenant_id)
30+
model_list = []
31+
for record in records:
32+
model_list.append(
33+
ModelConfig(cite_name=record["display_name"],
34+
api_key=record.get("api_key", ""),
35+
model_name=add_repo_to_name(
36+
model_repo=record["model_repo"],
37+
model_name=record["model_name"],
38+
),
39+
url=record["base_url"]))
40+
# fit for old version, main_model and sub_model use default model
2741
main_model_config = tenant_config_manager.get_model_config(
2842
key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id)
29-
sub_model_config = tenant_config_manager.get_model_config(
30-
key=MODEL_CONFIG_MAPPING["llmSecondary"], tenant_id=tenant_id)
31-
32-
return [ModelConfig(cite_name="main_model",
33-
api_key=main_model_config.get("api_key", ""),
34-
model_name=get_model_name_from_config(main_model_config) if main_model_config.get(
35-
"model_name") else "",
36-
url=main_model_config.get("base_url", "")),
37-
ModelConfig(cite_name="sub_model",
38-
api_key=sub_model_config.get("api_key", ""),
39-
model_name=get_model_name_from_config(sub_model_config) if sub_model_config.get(
40-
"model_name") else "",
41-
url=sub_model_config.get("base_url", ""))]
43+
model_list.append(
44+
ModelConfig(cite_name="main_model",
45+
api_key=main_model_config.get("api_key", ""),
46+
model_name=get_model_name_from_config(main_model_config) if main_model_config.get(
47+
"model_name") else "",
48+
url=main_model_config.get("base_url", "")))
49+
model_list.append(
50+
ModelConfig(cite_name="sub_model",
51+
api_key=main_model_config.get("api_key", ""),
52+
model_name=get_model_name_from_config(main_model_config) if main_model_config.get(
53+
"model_name") else "",
54+
url=main_model_config.get("base_url", "")))
55+
56+
return model_list
4257

4358

4459
async def create_agent_config(

test/backend/agents/test_create_agent_info.py

Lines changed: 188 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,72 @@
22
import sys
33
from unittest.mock import AsyncMock, MagicMock, patch, Mock, PropertyMock
44

5+
# Mock consts module first to avoid ModuleNotFoundError
6+
consts_mock = MagicMock()
7+
consts_mock.const = MagicMock()
8+
# Set required constants in consts.const
9+
consts_mock.const.MINIO_ENDPOINT = "http://localhost:9000"
10+
consts_mock.const.MINIO_ACCESS_KEY = "test_access_key"
11+
consts_mock.const.MINIO_SECRET_KEY = "test_secret_key"
12+
consts_mock.const.MINIO_REGION = "us-east-1"
13+
consts_mock.const.MINIO_DEFAULT_BUCKET = "test-bucket"
14+
consts_mock.const.POSTGRES_HOST = "localhost"
15+
consts_mock.const.POSTGRES_USER = "test_user"
16+
consts_mock.const.NEXENT_POSTGRES_PASSWORD = "test_password"
17+
consts_mock.const.POSTGRES_DB = "test_db"
18+
consts_mock.const.POSTGRES_PORT = 5432
19+
consts_mock.const.DEFAULT_TENANT_ID = "default_tenant"
20+
consts_mock.const.LOCAL_MCP_SERVER = "http://localhost:5011"
21+
consts_mock.const.MODEL_CONFIG_MAPPING = {"llm": "llm_config"}
22+
consts_mock.const.LANGUAGE = {"ZH": "zh"}
23+
24+
# Add the mocked consts module to sys.modules
25+
sys.modules['consts'] = consts_mock
26+
sys.modules['consts.const'] = consts_mock.const
27+
28+
# Mock utils module
29+
utils_mock = MagicMock()
30+
utils_mock.auth_utils = MagicMock()
31+
utils_mock.auth_utils.get_current_user_id = MagicMock(return_value=("test_user_id", "test_tenant_id"))
32+
33+
# Add the mocked utils module to sys.modules
34+
sys.modules['utils'] = utils_mock
35+
sys.modules['utils.auth_utils'] = utils_mock.auth_utils
36+
37+
# Provide a stub for the `boto3` module so that it can be imported safely even
38+
# if the testing environment does not have it available.
39+
boto3_mock = MagicMock()
40+
sys.modules['boto3'] = boto3_mock
41+
42+
# Mock the entire client module
43+
client_mock = MagicMock()
44+
client_mock.MinioClient = MagicMock()
45+
client_mock.PostgresClient = MagicMock()
46+
client_mock.db_client = MagicMock()
47+
client_mock.get_db_session = MagicMock()
48+
client_mock.as_dict = MagicMock()
49+
50+
# Add the mocked client module to sys.modules
51+
sys.modules['backend.database.client'] = client_mock
552

653
# Mock external dependencies before imports
754
sys.modules['nexent.core.utils.observer'] = MagicMock()
855
sys.modules['nexent.core.agents.agent_model'] = MagicMock()
956
sys.modules['smolagents.agents'] = MagicMock()
1057
sys.modules['smolagents.utils'] = MagicMock()
1158
sys.modules['services.remote_mcp_service'] = MagicMock()
12-
sys.modules['utils.auth_utils'] = MagicMock()
1359
sys.modules['database.agent_db'] = MagicMock()
1460
sys.modules['database.tool_db'] = MagicMock()
61+
sys.modules['database.model_management_db'] = MagicMock()
1562
sys.modules['services.elasticsearch_service'] = MagicMock()
1663
sys.modules['services.tenant_config_service'] = MagicMock()
1764
sys.modules['utils.prompt_template_utils'] = MagicMock()
1865
sys.modules['utils.config_utils'] = MagicMock()
1966
sys.modules['utils.langchain_utils'] = MagicMock()
67+
sys.modules['utils.model_name_utils'] = MagicMock()
2068
sys.modules['langchain_core.tools'] = MagicMock()
2169
sys.modules['services.memory_config_service'] = MagicMock()
2270
sys.modules['nexent.memory.memory_service'] = MagicMock()
23-
sys.modules['consts.const'] = MagicMock()
2471

2572
# Create mock classes that might be imported
2673
mock_agent_config = MagicMock()
@@ -38,9 +85,6 @@
3885
# Mock BASE_BUILTIN_MODULES
3986
sys.modules['smolagents.utils'].BASE_BUILTIN_MODULES = ["os", "sys", "json"]
4087

41-
# Mock LOCAL_MCP_SERVER constant
42-
sys.modules['consts.const'].LOCAL_MCP_SERVER = "http://localhost:5011"
43-
4488
# Now import the module under test
4589
from backend.agents.create_agent_info import (
4690
discover_langchain_tools,
@@ -53,6 +97,9 @@
5397
prepare_prompt_templates
5498
)
5599

100+
# Import constants for testing
101+
from consts.const import MODEL_CONFIG_MAPPING
102+
56103

57104
class TestDiscoverLangchainTools:
58105
"""Tests for the discover_langchain_tools function"""
@@ -592,33 +639,153 @@ class TestCreateModelConfigList:
592639
@pytest.mark.asyncio
593640
async def test_create_model_config_list(self):
594641
"""Test case for model configuration list creation"""
595-
with patch('backend.agents.create_agent_info.tenant_config_manager') as mock_manager, \
596-
patch('backend.agents.create_agent_info.get_model_name_from_config') as mock_get_model_name:
597-
598-
# Set mock return values
599-
mock_manager.get_model_config.side_effect = [
642+
# Reset mock call count before test
643+
mock_model_config.reset_mock()
644+
645+
with patch('backend.agents.create_agent_info.get_model_records') as mock_get_records, \
646+
patch('backend.agents.create_agent_info.tenant_config_manager') as mock_manager, \
647+
patch('backend.agents.create_agent_info.get_model_name_from_config') as mock_get_model_name, \
648+
patch('backend.agents.create_agent_info.add_repo_to_name') as mock_add_repo:
649+
650+
# Mock database records
651+
mock_get_records.return_value = [
600652
{
601-
"api_key": "main_key",
602-
"model_name": "main_model",
603-
"base_url": "http://main.url",
604-
"is_deep_thinking": True
653+
"display_name": "GPT-4",
654+
"api_key": "gpt4_key",
655+
"model_repo": "openai",
656+
"model_name": "gpt-4",
657+
"base_url": "https://api.openai.com"
605658
},
606659
{
607-
"api_key": "sub_key",
608-
"model_name": "sub_model",
609-
"base_url": "http://sub.url",
610-
"is_deep_thinking": False
660+
"display_name": "Claude",
661+
"api_key": "claude_key",
662+
"model_repo": "anthropic",
663+
"model_name": "claude-3",
664+
"base_url": "https://api.anthropic.com"
611665
}
612666
]
613667

614-
mock_get_model_name.side_effect = [
615-
"main_model_name", "sub_model_name"]
668+
# Mock tenant config for main_model and sub_model
669+
mock_manager.get_model_config.return_value = {
670+
"api_key": "main_key",
671+
"model_name": "main_model",
672+
"base_url": "http://main.url"
673+
}
674+
675+
# Mock utility functions
676+
mock_add_repo.side_effect = ["openai/gpt-4", "anthropic/claude-3"]
677+
mock_get_model_name.return_value = "main_model_name"
678+
679+
result = await create_model_config_list("tenant_1")
680+
681+
# Should have 4 models: 2 from database + 2 default (main_model, sub_model)
682+
assert len(result) == 4
683+
684+
# Verify get_model_records was called correctly
685+
mock_get_records.assert_called_once_with({"model_type": "llm"}, "tenant_1")
686+
687+
# Verify tenant_config_manager was called for default models
688+
mock_manager.get_model_config.assert_called_once_with(
689+
key=MODEL_CONFIG_MAPPING["llm"], tenant_id="tenant_1")
690+
691+
# Verify ModelConfig was called 4 times
692+
assert mock_model_config.call_count == 4
693+
694+
# Verify the calls to ModelConfig
695+
calls = mock_model_config.call_args_list
696+
697+
# First call: GPT-4 model from database
698+
assert calls[0][1]['cite_name'] == "GPT-4"
699+
assert calls[0][1]['api_key'] == "gpt4_key"
700+
assert calls[0][1]['model_name'] == "openai/gpt-4"
701+
assert calls[0][1]['url'] == "https://api.openai.com"
702+
703+
# Second call: Claude model from database
704+
assert calls[1][1]['cite_name'] == "Claude"
705+
assert calls[1][1]['api_key'] == "claude_key"
706+
assert calls[1][1]['model_name'] == "anthropic/claude-3"
707+
assert calls[1][1]['url'] == "https://api.anthropic.com"
708+
709+
# Third call: main_model
710+
assert calls[2][1]['cite_name'] == "main_model"
711+
assert calls[2][1]['api_key'] == "main_key"
712+
assert calls[2][1]['model_name'] == "main_model_name"
713+
assert calls[2][1]['url'] == "http://main.url"
714+
715+
# Fourth call: sub_model
716+
assert calls[3][1]['cite_name'] == "sub_model"
717+
assert calls[3][1]['api_key'] == "main_key"
718+
assert calls[3][1]['model_name'] == "main_model_name"
719+
assert calls[3][1]['url'] == "http://main.url"
720+
721+
@pytest.mark.asyncio
722+
async def test_create_model_config_list_empty_database(self):
723+
"""Test case when database returns no records"""
724+
# Reset mock call count before test
725+
mock_model_config.reset_mock()
726+
727+
with patch('backend.agents.create_agent_info.get_model_records') as mock_get_records, \
728+
patch('backend.agents.create_agent_info.tenant_config_manager') as mock_manager, \
729+
patch('backend.agents.create_agent_info.get_model_name_from_config') as mock_get_model_name:
730+
731+
# Mock empty database records
732+
mock_get_records.return_value = []
733+
734+
# Mock tenant config for main_model and sub_model
735+
mock_manager.get_model_config.return_value = {
736+
"api_key": "main_key",
737+
"model_name": "main_model",
738+
"base_url": "http://main.url"
739+
}
740+
741+
mock_get_model_name.return_value = "main_model_name"
742+
743+
result = await create_model_config_list("tenant_1")
744+
745+
# Should have 2 models: only default models (main_model, sub_model)
746+
assert len(result) == 2
747+
748+
# Verify ModelConfig was called 2 times
749+
assert mock_model_config.call_count == 2
750+
751+
# Verify both calls are for default models
752+
calls = mock_model_config.call_args_list
753+
assert calls[0][1]['cite_name'] == "main_model"
754+
assert calls[1][1]['cite_name'] == "sub_model"
755+
756+
@pytest.mark.asyncio
757+
async def test_create_model_config_list_no_model_name_in_config(self):
758+
"""Test case when tenant config has no model_name"""
759+
# Reset mock call count before test
760+
mock_model_config.reset_mock()
761+
762+
with patch('backend.agents.create_agent_info.get_model_records') as mock_get_records, \
763+
patch('backend.agents.create_agent_info.tenant_config_manager') as mock_manager, \
764+
patch('backend.agents.create_agent_info.get_model_name_from_config') as mock_get_model_name:
765+
766+
# Mock empty database records
767+
mock_get_records.return_value = []
768+
769+
# Mock tenant config without model_name
770+
mock_manager.get_model_config.return_value = {
771+
"api_key": "main_key",
772+
"base_url": "http://main.url"
773+
# No model_name field
774+
}
616775

617776
result = await create_model_config_list("tenant_1")
618777

778+
# Should have 2 models: only default models (main_model, sub_model)
619779
assert len(result) == 2
620-
# Verify that ModelConfig was called twice
780+
781+
# Verify ModelConfig was called 2 times with empty model_name
621782
assert mock_model_config.call_count == 2
783+
784+
calls = mock_model_config.call_args_list
785+
assert calls[0][1]['cite_name'] == "main_model"
786+
assert calls[0][1]['model_name'] == "" # Should be empty when no model_name in config
787+
assert calls[1][1]['cite_name'] == "sub_model"
788+
assert calls[1][1]['model_name'] == "" # Should be empty when no model_name in config
622789

623790

624791
class TestFilterMcpServersAndTools:

0 commit comments

Comments
 (0)