1- from unittest .mock import AsyncMock , MagicMock , patch
1+ import sys
2+ from unittest .mock import ANY , AsyncMock , MagicMock , patch
23
34import pytest
4- from mcp import McpError
5+ from mcp import ErrorData , McpError
56
67from vectorcode .cli_utils import Config
78from vectorcode .mcp_main import (
9+ get_arg_parser ,
810 list_collections ,
911 ls_files ,
12+ mcp_config ,
1013 mcp_server ,
14+ parse_cli_args ,
1115 query_tool ,
1216 rm_files ,
1317 vectorise_files ,
@@ -30,13 +34,14 @@ async def test_list_collections_success():
3034
3135@pytest .mark .asyncio
3236async def test_query_tool_invalid_project_root ():
33- with pytest .raises (McpError ) as exc_info :
34- await query_tool (
35- n_query = 5 ,
36- query_messages = ["keyword1" , "keyword2" ],
37- project_root = "invalid_path" ,
38- )
39- assert exc_info .value .error .code == 1
37+ with patch ("os.path.isdir" , return_value = False ):
38+ with pytest .raises (McpError ) as exc_info :
39+ await query_tool (
40+ n_query = 5 ,
41+ query_messages = ["keyword1" , "keyword2" ],
42+ project_root = "invalid_path" ,
43+ )
44+ assert exc_info .value .error .code == 1
4045
4146
4247@pytest .mark .asyncio
@@ -45,10 +50,16 @@ async def test_query_tool_success(tmp_path):
4550 with (
4651 patch ("vectorcode.mcp_main.get_database_connector" ) as mock_get_db ,
4752 patch ("vectorcode.mcp_main.get_project_config" , return_value = mock_config ),
53+ patch (
54+ "vectorcode.subcommands.query.reranker.naive.NaiveReranker.rerank" ,
55+ new_callable = AsyncMock ,
56+ return_value = [],
57+ ),
4858 ):
4959 mock_db = AsyncMock ()
5060 mock_get_db .return_value = mock_db
5161 mock_db ._configs = mock_config
62+ mock_db .query .return_value = []
5263
5364 await query_tool (
5465 n_query = 2 , query_messages = ["keyword1" ], project_root = str (tmp_path )
@@ -69,22 +80,68 @@ async def test_vectorise_tool_invalid_project_root():
6980async def test_vectorise_files_success (tmp_path ):
7081 mock_db = AsyncMock ()
7182 mock_config = Config (project_root = str (tmp_path ))
83+ (tmp_path / "file1.py" ).touch ()
84+ with (
85+ patch ("vectorcode.mcp_main.get_database_connector" , return_value = mock_db ),
86+ patch ("vectorcode.mcp_main.get_project_config" , return_value = mock_config ),
87+ patch (
88+ "vectorcode.mcp_main.vectorise_worker" , new_callable = AsyncMock
89+ ) as mock_worker ,
90+ ):
91+ await vectorise_files (
92+ paths = [str (tmp_path / "file1.py" )], project_root = str (tmp_path )
93+ )
94+ mock_worker .assert_called_once ()
95+
96+
97+ @pytest .mark .asyncio
98+ async def test_vectorise_files_with_ignore_spec (tmp_path ):
99+ project_root = tmp_path
100+ (project_root / ".gitignore" ).write_text ("ignored.py" )
101+ (project_root / "file1.py" ).touch ()
102+ (project_root / "ignored.py" ).touch ()
103+
104+ mock_db = AsyncMock ()
105+ mock_config = Config (project_root = str (project_root ))
72106 with (
73107 patch ("vectorcode.mcp_main.get_database_connector" , return_value = mock_db ),
74108 patch ("vectorcode.mcp_main.get_project_config" , return_value = mock_config ),
75- patch ("os.path.isfile" , side_effect = lambda x : x == "file1.py" ),
109+ patch (
110+ "vectorcode.mcp_main.vectorise_worker" , new_callable = AsyncMock
111+ ) as mock_worker ,
76112 ):
77- await vectorise_files (paths = ["file1.py" ], project_root = str (tmp_path ))
78- mock_db .vectorise .assert_called_with (file_path = "file1.py" )
113+ await vectorise_files (
114+ paths = [str (project_root / "file1.py" ), str (project_root / "ignored.py" )],
115+ project_root = str (project_root ),
116+ )
117+ mock_worker .assert_called_once_with (
118+ mock_db , str (project_root / "file1.py" ), ANY , ANY , ANY
119+ )
120+
121+
122+ @pytest .mark .asyncio
123+ async def test_mcp_server (tmp_path ):
124+ with (
125+ patch ("mcp.server.fastmcp.FastMCP.add_tool" ) as mock_add_tool ,
126+ patch ("vectorcode.mcp_main.find_project_config_dir" , return_value = tmp_path ),
127+ patch ("vectorcode.mcp_main.get_project_config" , return_value = Config ()),
128+ ):
129+ await mcp_server ()
130+ assert mock_add_tool .call_count > 0
79131
80132
81133@pytest .mark .asyncio
82- async def test_mcp_server ( ):
134+ async def test_mcp_server_ls_on_start ( tmp_path ):
83135 with (
84136 patch ("mcp.server.fastmcp.FastMCP.add_tool" ) as mock_add_tool ,
137+ patch ("vectorcode.mcp_main.find_project_config_dir" , return_value = tmp_path ),
138+ patch ("vectorcode.mcp_main.get_project_config" , return_value = Config ()),
139+ patch ("vectorcode.mcp_main.list_collections" , return_value = ["path1" , "path2" ]),
85140 ):
141+ mcp_config .ls_on_start = True
86142 await mcp_server ()
87143 assert mock_add_tool .call_count > 0
144+ mcp_config .ls_on_start = False
88145
89146
90147@pytest .mark .asyncio
@@ -108,15 +165,123 @@ async def test_ls_files_success(tmp_path):
108165
109166@pytest .mark .asyncio
110167async def test_rm_files_success (tmp_path ):
168+ (tmp_path / "file1.py" ).touch ()
111169 with (
112170 patch ("vectorcode.mcp_main.get_database_connector" ) as mock_get_db ,
113171 patch ("vectorcode.mcp_main.get_project_config" ) as mock_get_config ,
114- patch ("os.path.isfile" , side_effect = lambda x : x == "file1.py" ),
115172 ):
116173 mock_db = AsyncMock ()
117174 mock_get_db .return_value = mock_db
118175 mock_get_config .return_value = Config (project_root = str (tmp_path ))
119176
120- await rm_files (files = ["file1.py" ], project_root = str (tmp_path ))
177+ await rm_files (files = [str ( tmp_path / "file1.py" ) ], project_root = str (tmp_path ))
121178
122179 mock_db .delete .assert_called_once ()
180+
181+
182+ @pytest .mark .asyncio
183+ async def test_rm_files_no_files (tmp_path ):
184+ with (
185+ patch ("vectorcode.mcp_main.get_database_connector" ) as mock_get_db ,
186+ patch ("vectorcode.mcp_main.get_project_config" ) as mock_get_config ,
187+ ):
188+ mock_db = AsyncMock ()
189+ mock_get_db .return_value = mock_db
190+ mock_get_config .return_value = Config (project_root = str (tmp_path ))
191+
192+ await rm_files (files = ["file1.py" ], project_root = str (tmp_path ))
193+
194+ mock_db .delete .assert_not_called ()
195+
196+
197+ def test_get_arg_parser ():
198+ parser = get_arg_parser ()
199+ args = parser .parse_args (["-n" , "5" , "--ls-on-start" ])
200+ assert args .number == 5
201+ assert args .ls_on_start is True
202+
203+
204+ def test_parse_cli_args ():
205+ with patch .object (sys , "argv" , ["" , "-n" , "5" , "--ls-on-start" ]):
206+ config = parse_cli_args ()
207+ assert config .n_results == 5
208+ assert config .ls_on_start is True
209+
210+
211+ @pytest .mark .asyncio
212+ async def test_vectorise_files_exception (tmp_path ):
213+ mock_db = AsyncMock ()
214+ mock_config = Config (project_root = str (tmp_path ))
215+ (tmp_path / "file1.py" ).touch ()
216+ with (
217+ patch ("vectorcode.mcp_main.get_database_connector" , return_value = mock_db ),
218+ patch ("vectorcode.mcp_main.get_project_config" , return_value = mock_config ),
219+ patch (
220+ "vectorcode.mcp_main.vectorise_worker" , side_effect = Exception ("test error" )
221+ ),
222+ ):
223+ with pytest .raises (McpError ):
224+ await vectorise_files (
225+ paths = [str (tmp_path / "file1.py" )], project_root = str (tmp_path )
226+ )
227+
228+
229+ @pytest .mark .asyncio
230+ async def test_query_tool_exception (tmp_path ):
231+ mock_config = Config (project_root = tmp_path )
232+ with (
233+ patch ("vectorcode.mcp_main.get_database_connector" ) as mock_get_db ,
234+ patch ("vectorcode.mcp_main.get_project_config" , return_value = mock_config ),
235+ patch (
236+ "vectorcode.mcp_main.get_reranked_results" ,
237+ side_effect = Exception ("test error" ),
238+ ),
239+ ):
240+ mock_db = AsyncMock ()
241+ mock_get_db .return_value = mock_db
242+ mock_db ._configs = mock_config
243+
244+ with pytest .raises (McpError ):
245+ await query_tool (
246+ n_query = 2 , query_messages = ["keyword1" ], project_root = str (tmp_path )
247+ )
248+
249+
250+ @pytest .mark .asyncio
251+ async def test_vectorise_files_mcp_exception (tmp_path ):
252+ mock_db = AsyncMock ()
253+ mock_config = Config (project_root = str (tmp_path ))
254+ (tmp_path / "file1.py" ).touch ()
255+ with (
256+ patch ("vectorcode.mcp_main.get_database_connector" , return_value = mock_db ),
257+ patch ("vectorcode.mcp_main.get_project_config" , return_value = mock_config ),
258+ patch (
259+ "vectorcode.mcp_main.vectorise_worker" ,
260+ side_effect = McpError (ErrorData (code = 1 , message = "test error" )),
261+ ),
262+ ):
263+ with pytest .raises (McpError ):
264+ await vectorise_files (
265+ paths = [str (tmp_path / "file1.py" )], project_root = str (tmp_path )
266+ )
267+
268+
269+ @pytest .mark .asyncio
270+ async def test_query_tool_mcp_exception (tmp_path ):
271+ mock_config = Config (project_root = tmp_path )
272+ with (
273+ patch ("vectorcode.mcp_main.get_database_connector" ) as mock_get_db ,
274+ patch ("vectorcode.mcp_main.get_project_config" , return_value = mock_config ),
275+ patch (
276+ "vectorcode.mcp_main.get_reranked_results" ,
277+ side_effect = McpError (ErrorData (code = 1 , message = "test error" )),
278+ ),
279+ ):
280+ mock_db = AsyncMock ()
281+ mock_get_db .return_value = mock_db
282+ mock_db ._configs = mock_config
283+
284+ with pytest .raises (McpError ):
285+ await query_tool (
286+ n_query = 2 , query_messages = ["keyword1" ], project_root = str (tmp_path )
287+ )
0 commit comments