|
2 | 2 | import sys |
3 | 3 | from unittest.mock import AsyncMock, MagicMock, patch, Mock, PropertyMock |
4 | 4 |
|
| 5 | +# Mock consts module first to avoid ModuleNotFoundError |
| 6 | +consts_mock = MagicMock() |
| 7 | +consts_mock.const = MagicMock() |
| 8 | +# Set required constants in consts.const |
| 9 | +consts_mock.const.MINIO_ENDPOINT = "http://localhost:9000" |
| 10 | +consts_mock.const.MINIO_ACCESS_KEY = "test_access_key" |
| 11 | +consts_mock.const.MINIO_SECRET_KEY = "test_secret_key" |
| 12 | +consts_mock.const.MINIO_REGION = "us-east-1" |
| 13 | +consts_mock.const.MINIO_DEFAULT_BUCKET = "test-bucket" |
| 14 | +consts_mock.const.POSTGRES_HOST = "localhost" |
| 15 | +consts_mock.const.POSTGRES_USER = "test_user" |
| 16 | +consts_mock.const.NEXENT_POSTGRES_PASSWORD = "test_password" |
| 17 | +consts_mock.const.POSTGRES_DB = "test_db" |
| 18 | +consts_mock.const.POSTGRES_PORT = 5432 |
| 19 | +consts_mock.const.DEFAULT_TENANT_ID = "default_tenant" |
| 20 | +consts_mock.const.LOCAL_MCP_SERVER = "http://localhost:5011" |
| 21 | +consts_mock.const.MODEL_CONFIG_MAPPING = {"llm": "llm_config"} |
| 22 | +consts_mock.const.LANGUAGE = {"ZH": "zh"} |
| 23 | + |
| 24 | +# Add the mocked consts module to sys.modules |
| 25 | +sys.modules['consts'] = consts_mock |
| 26 | +sys.modules['consts.const'] = consts_mock.const |
| 27 | + |
| 28 | +# Mock utils module |
| 29 | +utils_mock = MagicMock() |
| 30 | +utils_mock.auth_utils = MagicMock() |
| 31 | +utils_mock.auth_utils.get_current_user_id = MagicMock(return_value=("test_user_id", "test_tenant_id")) |
| 32 | + |
| 33 | +# Add the mocked utils module to sys.modules |
| 34 | +sys.modules['utils'] = utils_mock |
| 35 | +sys.modules['utils.auth_utils'] = utils_mock.auth_utils |
| 36 | + |
| 37 | +# Provide a stub for the `boto3` module so that it can be imported safely even |
| 38 | +# if the testing environment does not have it available. |
| 39 | +boto3_mock = MagicMock() |
| 40 | +sys.modules['boto3'] = boto3_mock |
| 41 | + |
| 42 | +# Mock the entire client module |
| 43 | +client_mock = MagicMock() |
| 44 | +client_mock.MinioClient = MagicMock() |
| 45 | +client_mock.PostgresClient = MagicMock() |
| 46 | +client_mock.db_client = MagicMock() |
| 47 | +client_mock.get_db_session = MagicMock() |
| 48 | +client_mock.as_dict = MagicMock() |
| 49 | + |
| 50 | +# Add the mocked client module to sys.modules |
| 51 | +sys.modules['backend.database.client'] = client_mock |
5 | 52 |
|
6 | 53 | # Mock external dependencies before imports |
7 | 54 | sys.modules['nexent.core.utils.observer'] = MagicMock() |
8 | 55 | sys.modules['nexent.core.agents.agent_model'] = MagicMock() |
9 | 56 | sys.modules['smolagents.agents'] = MagicMock() |
10 | 57 | sys.modules['smolagents.utils'] = MagicMock() |
11 | 58 | sys.modules['services.remote_mcp_service'] = MagicMock() |
12 | | -sys.modules['utils.auth_utils'] = MagicMock() |
13 | 59 | sys.modules['database.agent_db'] = MagicMock() |
14 | 60 | sys.modules['database.tool_db'] = MagicMock() |
| 61 | +sys.modules['database.model_management_db'] = MagicMock() |
15 | 62 | sys.modules['services.elasticsearch_service'] = MagicMock() |
16 | 63 | sys.modules['services.tenant_config_service'] = MagicMock() |
17 | 64 | sys.modules['utils.prompt_template_utils'] = MagicMock() |
18 | 65 | sys.modules['utils.config_utils'] = MagicMock() |
19 | 66 | sys.modules['utils.langchain_utils'] = MagicMock() |
| 67 | +sys.modules['utils.model_name_utils'] = MagicMock() |
20 | 68 | sys.modules['langchain_core.tools'] = MagicMock() |
21 | 69 | sys.modules['services.memory_config_service'] = MagicMock() |
22 | 70 | sys.modules['nexent.memory.memory_service'] = MagicMock() |
23 | | -sys.modules['consts.const'] = MagicMock() |
24 | 71 |
|
25 | 72 | # Create mock classes that might be imported |
26 | 73 | mock_agent_config = MagicMock() |
|
38 | 85 | # Mock BASE_BUILTIN_MODULES |
39 | 86 | sys.modules['smolagents.utils'].BASE_BUILTIN_MODULES = ["os", "sys", "json"] |
40 | 87 |
|
41 | | -# Mock LOCAL_MCP_SERVER constant |
42 | | -sys.modules['consts.const'].LOCAL_MCP_SERVER = "http://localhost:5011" |
43 | | - |
44 | 88 | # Now import the module under test |
45 | 89 | from backend.agents.create_agent_info import ( |
46 | 90 | discover_langchain_tools, |
|
53 | 97 | prepare_prompt_templates |
54 | 98 | ) |
55 | 99 |
|
| 100 | +# Import constants for testing |
| 101 | +from consts.const import MODEL_CONFIG_MAPPING |
| 102 | + |
56 | 103 |
|
57 | 104 | class TestDiscoverLangchainTools: |
58 | 105 | """Tests for the discover_langchain_tools function""" |
@@ -592,33 +639,153 @@ class TestCreateModelConfigList: |
592 | 639 | @pytest.mark.asyncio |
593 | 640 | async def test_create_model_config_list(self): |
594 | 641 | """Test case for model configuration list creation""" |
595 | | - with patch('backend.agents.create_agent_info.tenant_config_manager') as mock_manager, \ |
596 | | - patch('backend.agents.create_agent_info.get_model_name_from_config') as mock_get_model_name: |
597 | | - |
598 | | - # Set mock return values |
599 | | - mock_manager.get_model_config.side_effect = [ |
| 642 | + # Reset mock call count before test |
| 643 | + mock_model_config.reset_mock() |
| 644 | + |
| 645 | + with patch('backend.agents.create_agent_info.get_model_records') as mock_get_records, \ |
| 646 | + patch('backend.agents.create_agent_info.tenant_config_manager') as mock_manager, \ |
| 647 | + patch('backend.agents.create_agent_info.get_model_name_from_config') as mock_get_model_name, \ |
| 648 | + patch('backend.agents.create_agent_info.add_repo_to_name') as mock_add_repo: |
| 649 | + |
| 650 | + # Mock database records |
| 651 | + mock_get_records.return_value = [ |
600 | 652 | { |
601 | | - "api_key": "main_key", |
602 | | - "model_name": "main_model", |
603 | | - "base_url": "http://main.url", |
604 | | - "is_deep_thinking": True |
| 653 | + "display_name": "GPT-4", |
| 654 | + "api_key": "gpt4_key", |
| 655 | + "model_repo": "openai", |
| 656 | + "model_name": "gpt-4", |
| 657 | + "base_url": "https://api.openai.com" |
605 | 658 | }, |
606 | 659 | { |
607 | | - "api_key": "sub_key", |
608 | | - "model_name": "sub_model", |
609 | | - "base_url": "http://sub.url", |
610 | | - "is_deep_thinking": False |
| 660 | + "display_name": "Claude", |
| 661 | + "api_key": "claude_key", |
| 662 | + "model_repo": "anthropic", |
| 663 | + "model_name": "claude-3", |
| 664 | + "base_url": "https://api.anthropic.com" |
611 | 665 | } |
612 | 666 | ] |
613 | 667 |
|
614 | | - mock_get_model_name.side_effect = [ |
615 | | - "main_model_name", "sub_model_name"] |
| 668 | + # Mock tenant config for main_model and sub_model |
| 669 | + mock_manager.get_model_config.return_value = { |
| 670 | + "api_key": "main_key", |
| 671 | + "model_name": "main_model", |
| 672 | + "base_url": "http://main.url" |
| 673 | + } |
| 674 | + |
| 675 | + # Mock utility functions |
| 676 | + mock_add_repo.side_effect = ["openai/gpt-4", "anthropic/claude-3"] |
| 677 | + mock_get_model_name.return_value = "main_model_name" |
| 678 | + |
| 679 | + result = await create_model_config_list("tenant_1") |
| 680 | + |
| 681 | + # Should have 4 models: 2 from database + 2 default (main_model, sub_model) |
| 682 | + assert len(result) == 4 |
| 683 | + |
| 684 | + # Verify get_model_records was called correctly |
| 685 | + mock_get_records.assert_called_once_with({"model_type": "llm"}, "tenant_1") |
| 686 | + |
| 687 | + # Verify tenant_config_manager was called for default models |
| 688 | + mock_manager.get_model_config.assert_called_once_with( |
| 689 | + key=MODEL_CONFIG_MAPPING["llm"], tenant_id="tenant_1") |
| 690 | + |
| 691 | + # Verify ModelConfig was called 4 times |
| 692 | + assert mock_model_config.call_count == 4 |
| 693 | + |
| 694 | + # Verify the calls to ModelConfig |
| 695 | + calls = mock_model_config.call_args_list |
| 696 | + |
| 697 | + # First call: GPT-4 model from database |
| 698 | + assert calls[0][1]['cite_name'] == "GPT-4" |
| 699 | + assert calls[0][1]['api_key'] == "gpt4_key" |
| 700 | + assert calls[0][1]['model_name'] == "openai/gpt-4" |
| 701 | + assert calls[0][1]['url'] == "https://api.openai.com" |
| 702 | + |
| 703 | + # Second call: Claude model from database |
| 704 | + assert calls[1][1]['cite_name'] == "Claude" |
| 705 | + assert calls[1][1]['api_key'] == "claude_key" |
| 706 | + assert calls[1][1]['model_name'] == "anthropic/claude-3" |
| 707 | + assert calls[1][1]['url'] == "https://api.anthropic.com" |
| 708 | + |
| 709 | + # Third call: main_model |
| 710 | + assert calls[2][1]['cite_name'] == "main_model" |
| 711 | + assert calls[2][1]['api_key'] == "main_key" |
| 712 | + assert calls[2][1]['model_name'] == "main_model_name" |
| 713 | + assert calls[2][1]['url'] == "http://main.url" |
| 714 | + |
| 715 | + # Fourth call: sub_model |
| 716 | + assert calls[3][1]['cite_name'] == "sub_model" |
| 717 | + assert calls[3][1]['api_key'] == "main_key" |
| 718 | + assert calls[3][1]['model_name'] == "main_model_name" |
| 719 | + assert calls[3][1]['url'] == "http://main.url" |
| 720 | + |
| 721 | + @pytest.mark.asyncio |
| 722 | + async def test_create_model_config_list_empty_database(self): |
| 723 | + """Test case when database returns no records""" |
| 724 | + # Reset mock call count before test |
| 725 | + mock_model_config.reset_mock() |
| 726 | + |
| 727 | + with patch('backend.agents.create_agent_info.get_model_records') as mock_get_records, \ |
| 728 | + patch('backend.agents.create_agent_info.tenant_config_manager') as mock_manager, \ |
| 729 | + patch('backend.agents.create_agent_info.get_model_name_from_config') as mock_get_model_name: |
| 730 | + |
| 731 | + # Mock empty database records |
| 732 | + mock_get_records.return_value = [] |
| 733 | + |
| 734 | + # Mock tenant config for main_model and sub_model |
| 735 | + mock_manager.get_model_config.return_value = { |
| 736 | + "api_key": "main_key", |
| 737 | + "model_name": "main_model", |
| 738 | + "base_url": "http://main.url" |
| 739 | + } |
| 740 | + |
| 741 | + mock_get_model_name.return_value = "main_model_name" |
| 742 | + |
| 743 | + result = await create_model_config_list("tenant_1") |
| 744 | + |
| 745 | + # Should have 2 models: only default models (main_model, sub_model) |
| 746 | + assert len(result) == 2 |
| 747 | + |
| 748 | + # Verify ModelConfig was called 2 times |
| 749 | + assert mock_model_config.call_count == 2 |
| 750 | + |
| 751 | + # Verify both calls are for default models |
| 752 | + calls = mock_model_config.call_args_list |
| 753 | + assert calls[0][1]['cite_name'] == "main_model" |
| 754 | + assert calls[1][1]['cite_name'] == "sub_model" |
| 755 | + |
| 756 | + @pytest.mark.asyncio |
| 757 | + async def test_create_model_config_list_no_model_name_in_config(self): |
| 758 | + """Test case when tenant config has no model_name""" |
| 759 | + # Reset mock call count before test |
| 760 | + mock_model_config.reset_mock() |
| 761 | + |
| 762 | + with patch('backend.agents.create_agent_info.get_model_records') as mock_get_records, \ |
| 763 | + patch('backend.agents.create_agent_info.tenant_config_manager') as mock_manager, \ |
| 764 | + patch('backend.agents.create_agent_info.get_model_name_from_config') as mock_get_model_name: |
| 765 | + |
| 766 | + # Mock empty database records |
| 767 | + mock_get_records.return_value = [] |
| 768 | + |
| 769 | + # Mock tenant config without model_name |
| 770 | + mock_manager.get_model_config.return_value = { |
| 771 | + "api_key": "main_key", |
| 772 | + "base_url": "http://main.url" |
| 773 | + # No model_name field |
| 774 | + } |
616 | 775 |
|
617 | 776 | result = await create_model_config_list("tenant_1") |
618 | 777 |
|
| 778 | + # Should have 2 models: only default models (main_model, sub_model) |
619 | 779 | assert len(result) == 2 |
620 | | - # Verify that ModelConfig was called twice |
| 780 | + |
| 781 | + # Verify ModelConfig was called 2 times with empty model_name |
621 | 782 | assert mock_model_config.call_count == 2 |
| 783 | + |
| 784 | + calls = mock_model_config.call_args_list |
| 785 | + assert calls[0][1]['cite_name'] == "main_model" |
| 786 | + assert calls[0][1]['model_name'] == "" # Should be empty when no model_name in config |
| 787 | + assert calls[1][1]['cite_name'] == "sub_model" |
| 788 | + assert calls[1][1]['model_name'] == "" # Should be empty when no model_name in config |
622 | 789 |
|
623 | 790 |
|
624 | 791 | class TestFilterMcpServersAndTools: |
|
0 commit comments