11import unittest
2- import logging
32import sys
43import os
5- from unittest .mock import MagicMock , patch , Mock , call
6- import pytest
4+ from unittest .mock import MagicMock , patch
75
86# 添加项目根目录到Python路径
9- sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' , '..' , '..' )))
7+ sys .path .insert (0 , os .path .abspath (os .path .join (
8+ os .path .dirname (__file__ ), '..' , '..' , '..' )))
109
1110# 模拟主要依赖
1211sys .modules ['langchain_core.tools' ] = MagicMock ()
@@ -31,12 +30,12 @@ def test_is_langchain_tool(self):
3130 """测试_is_langchain_tool函数"""
3231 # 创建一个BaseTool实例的模拟
3332 mock_tool = MagicMock ()
34-
33+
3534 # 模拟isinstance返回值
3635 with patch ('backend.utils.langchain_utils.isinstance' , return_value = True ):
3736 result = self ._is_langchain_tool (mock_tool )
3837 self .assertTrue (result )
39-
38+
4039 # 测试非BaseTool对象
4140 with patch ('backend.utils.langchain_utils.isinstance' , return_value = False ):
4241 result = self ._is_langchain_tool ("not a tool" )
@@ -46,44 +45,45 @@ def test_discover_langchain_modules_success(self):
4645 """测试成功发现LangChain工具的情况"""
4746 # 创建一个临时目录结构
4847 with patch ('os.path.isdir' , return_value = True ), \
49- patch ('os.listdir' , return_value = ['tool1.py' , 'tool2.py' , '__init__.py' , 'not_a_py_file.txt' ]), \
50- patch ('importlib.util.spec_from_file_location' ) as mock_spec , \
51- patch ('importlib.util.module_from_spec' ) as mock_module_from_spec :
52-
48+ patch ('os.listdir' , return_value = ['tool1.py' , 'tool2.py' , '__init__.py' , 'not_a_py_file.txt' ]), \
49+ patch ('importlib.util.spec_from_file_location' ) as mock_spec , \
50+ patch ('importlib.util.module_from_spec' ) as mock_module_from_spec :
51+
5352 # 创建模拟工具对象
5453 mock_tool1 = MagicMock (name = "tool1" )
5554 mock_tool2 = MagicMock (name = "tool2" )
56-
55+
5756 # 设置模拟module
5857 mock_module_obj1 = MagicMock ()
5958 mock_module_obj1 .tool_obj1 = mock_tool1
6059
6160 mock_module_obj2 = MagicMock ()
6261 mock_module_obj2 .tool_obj2 = mock_tool2
63-
64- mock_module_from_spec .side_effect = [mock_module_obj1 , mock_module_obj2 ]
65-
62+
63+ mock_module_from_spec .side_effect = [
64+ mock_module_obj1 , mock_module_obj2 ]
65+
6666 # 设置模拟spec和loader
6767 mock_spec_obj1 = MagicMock ()
6868 mock_spec_obj2 = MagicMock ()
6969 mock_spec .side_effect = [mock_spec_obj1 , mock_spec_obj2 ]
70-
70+
7171 mock_loader1 = MagicMock ()
7272 mock_loader2 = MagicMock ()
7373 mock_spec_obj1 .loader = mock_loader1
7474 mock_spec_obj2 .loader = mock_loader2
75-
75+
7676 # 设置过滤函数始终返回True
7777 def mock_filter (obj ):
7878 return obj is mock_tool1 or obj is mock_tool2
79-
79+
8080 # 执行函数
8181 result = self .discover_langchain_modules (filter_func = mock_filter )
82-
82+
8383 # 验证loader.exec_module被调用
8484 mock_loader1 .exec_module .assert_called_once_with (mock_module_obj1 )
8585 mock_loader2 .exec_module .assert_called_once_with (mock_module_obj2 )
86-
86+
8787 # 验证结果
8888 self .assertEqual (len (result ), 2 )
8989 discovered_objs = [obj for (obj , _ ) in result ]
@@ -93,39 +93,41 @@ def mock_filter(obj):
9393 def test_discover_langchain_modules_directory_not_found (self ):
9494 """测试目录不存在的情况"""
9595 with patch ('os.path.isdir' , return_value = False ):
96- result = self .discover_langchain_modules (directory = "non_existent_dir" )
96+ result = self .discover_langchain_modules (
97+ directory = "non_existent_dir" )
9798 self .assertEqual (result , [])
9899
99100 def test_discover_langchain_modules_module_exception (self ):
100101 """测试处理模块异常的情况"""
101102 with patch ('os.path.isdir' , return_value = True ), \
102- patch ('os.listdir' , return_value = ['error_module.py' ]), \
103- patch ('importlib.util.spec_from_file_location' ) as mock_spec , \
104- patch ('backend.utils.langchain_utils.logger' , logger_mock ):
105-
103+ patch ('os.listdir' , return_value = ['error_module.py' ]), \
104+ patch ('importlib.util.spec_from_file_location' ) as mock_spec , \
105+ patch ('backend.utils.langchain_utils.logger' , logger_mock ):
106+
106107 # 设置spec_from_file_location抛出异常
107108 mock_spec .side_effect = Exception ("Module error" )
108-
109+
109110 # 执行函数 - 应该捕获异常并继续
110111 result = self .discover_langchain_modules ()
111-
112+
112113 # 验证结果为空列表
113114 self .assertEqual (result , [])
114115 # 验证错误被记录
115116 self .assertTrue (logger_mock .error .called )
116117 # 验证错误消息包含预期内容
117- logger_mock .error .assert_called_with ("Error processing module error_module.py: Module error" )
118+ logger_mock .error .assert_called_with (
119+ "Error processing module error_module.py: Module error" )
118120
119121 def test_discover_langchain_modules_spec_loader_none (self ):
120122 """测试spec或loader为None的情况"""
121123 with patch ('os.path.isdir' , return_value = True ), \
122- patch ('os.listdir' , return_value = ['invalid_module.py' ]), \
123- patch ('importlib.util.spec_from_file_location' , return_value = None ), \
124- patch ('backend.utils.langchain_utils.logger' , logger_mock ):
125-
124+ patch ('os.listdir' , return_value = ['invalid_module.py' ]), \
125+ patch ('importlib.util.spec_from_file_location' , return_value = None ), \
126+ patch ('backend.utils.langchain_utils.logger' , logger_mock ):
127+
126128 # 执行函数
127129 result = self .discover_langchain_modules ()
128-
130+
129131 # 验证结果为空列表
130132 self .assertEqual (result , [])
131133 # 验证警告被记录
@@ -138,36 +140,36 @@ def test_discover_langchain_modules_spec_loader_none(self):
138140 def test_discover_langchain_modules_custom_filter (self ):
139141 """测试使用自定义过滤函数的情况"""
140142 with patch ('os.path.isdir' , return_value = True ), \
141- patch ('os.listdir' , return_value = ['tool.py' ]), \
142- patch ('importlib.util.spec_from_file_location' ) as mock_spec , \
143- patch ('importlib.util.module_from_spec' ) as mock_module_from_spec :
144-
143+ patch ('os.listdir' , return_value = ['tool.py' ]), \
144+ patch ('importlib.util.spec_from_file_location' ) as mock_spec , \
145+ patch ('importlib.util.module_from_spec' ) as mock_module_from_spec :
146+
145147 # 创建两个对象,一个通过过滤,一个不通过
146148 obj_pass = MagicMock (name = "pass_object" )
147149 obj_fail = MagicMock (name = "fail_object" )
148-
150+
149151 # 设置模拟module,使其包含我们的两个测试对象
150152 mock_module_obj = MagicMock ()
151153 mock_module_obj .obj_pass = obj_pass
152154 mock_module_obj .obj_fail = obj_fail
153155 mock_module_from_spec .return_value = mock_module_obj
154-
156+
155157 # 设置模拟spec和loader
156158 mock_spec_obj = MagicMock ()
157159 mock_spec .return_value = mock_spec_obj
158160 mock_loader = MagicMock ()
159161 mock_spec_obj .loader = mock_loader
160-
162+
161163 # 自定义过滤函数,只接受obj_pass
162164 def custom_filter (obj ):
163165 return obj is obj_pass
164-
166+
165167 # 执行函数
166168 result = self .discover_langchain_modules (filter_func = custom_filter )
167-
169+
168170 # 验证loader.exec_module被调用
169171 mock_loader .exec_module .assert_called_once_with (mock_module_obj )
170-
172+
171173 # 验证结果 - 应该只有一个对象通过过滤
172174 self .assertEqual (len (result ), 1 )
173175 self .assertEqual (result [0 ][0 ], obj_pass )
0 commit comments