Skip to content

Commit e1e476c

Browse files
committed
fix(mcp): Improve error handling and test coverage
1 parent 4b6d8d5 commit e1e476c

File tree

2 files changed

+186
-21
lines changed

2 files changed

+186
-21
lines changed

src/vectorcode/mcp_main.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,18 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]
130130
semaphore = asyncio.Semaphore(os.cpu_count() or 1)
131131
tasks = [
132132
asyncio.create_task(
133-
vectorise_worker(database, file, semaphore, stats, stats_lock)
133+
vectorise_worker(database, str(file), semaphore, stats, stats_lock)
134134
)
135-
for file in paths
135+
for file in final_config.files
136136
]
137137
for i, task in enumerate(asyncio.as_completed(tasks), start=1):
138138
await task
139139

140140
await database.check_orphanes()
141141

142142
return stats.to_dict()
143-
except Exception as e: # pragma: nocover
144-
if isinstance(e, McpError):
143+
except Exception as e:
144+
if isinstance(e, McpError): # pragma: nocover
145145
logger.error("Failed to access collection at %s", project_root)
146146
raise
147147
else:
@@ -185,8 +185,8 @@ async def query_tool(
185185
reranked_results = await get_reranked_results(config, database)
186186
return list(str(i) for i in _prepare_formatted_result(reranked_results))
187187

188-
except Exception as e: # pragma: nocover
189-
if isinstance(e, McpError):
188+
except Exception as e:
189+
if isinstance(e, McpError): # pragma: nocover
190190
logger.error("Failed to access collection at %s", project_root)
191191
raise
192192
else:

tests/test_mcp.py

Lines changed: 180 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
from unittest.mock import AsyncMock, MagicMock, patch
1+
import sys
2+
from unittest.mock import ANY, AsyncMock, MagicMock, patch
23

34
import pytest
4-
from mcp import McpError
5+
from mcp import ErrorData, McpError
56

67
from vectorcode.cli_utils import Config
78
from vectorcode.mcp_main import (
9+
get_arg_parser,
810
list_collections,
911
ls_files,
12+
mcp_config,
1013
mcp_server,
14+
parse_cli_args,
1115
query_tool,
1216
rm_files,
1317
vectorise_files,
@@ -30,13 +34,14 @@ async def test_list_collections_success():
3034

3135
@pytest.mark.asyncio
3236
async def test_query_tool_invalid_project_root():
33-
with pytest.raises(McpError) as exc_info:
34-
await query_tool(
35-
n_query=5,
36-
query_messages=["keyword1", "keyword2"],
37-
project_root="invalid_path",
38-
)
39-
assert exc_info.value.error.code == 1
37+
with patch("os.path.isdir", return_value=False):
38+
with pytest.raises(McpError) as exc_info:
39+
await query_tool(
40+
n_query=5,
41+
query_messages=["keyword1", "keyword2"],
42+
project_root="invalid_path",
43+
)
44+
assert exc_info.value.error.code == 1
4045

4146

4247
@pytest.mark.asyncio
@@ -45,10 +50,16 @@ async def test_query_tool_success(tmp_path):
4550
with (
4651
patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
4752
patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
53+
patch(
54+
"vectorcode.subcommands.query.reranker.naive.NaiveReranker.rerank",
55+
new_callable=AsyncMock,
56+
return_value=[],
57+
),
4858
):
4959
mock_db = AsyncMock()
5060
mock_get_db.return_value = mock_db
5161
mock_db._configs = mock_config
62+
mock_db.query.return_value = []
5263

5364
await query_tool(
5465
n_query=2, query_messages=["keyword1"], project_root=str(tmp_path)
@@ -69,22 +80,68 @@ async def test_vectorise_tool_invalid_project_root():
6980
async def test_vectorise_files_success(tmp_path):
7081
mock_db = AsyncMock()
7182
mock_config = Config(project_root=str(tmp_path))
83+
(tmp_path / "file1.py").touch()
84+
with (
85+
patch("vectorcode.mcp_main.get_database_connector", return_value=mock_db),
86+
patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
87+
patch(
88+
"vectorcode.mcp_main.vectorise_worker", new_callable=AsyncMock
89+
) as mock_worker,
90+
):
91+
await vectorise_files(
92+
paths=[str(tmp_path / "file1.py")], project_root=str(tmp_path)
93+
)
94+
mock_worker.assert_called_once()
95+
96+
97+
@pytest.mark.asyncio
98+
async def test_vectorise_files_with_ignore_spec(tmp_path):
99+
project_root = tmp_path
100+
(project_root / ".gitignore").write_text("ignored.py")
101+
(project_root / "file1.py").touch()
102+
(project_root / "ignored.py").touch()
103+
104+
mock_db = AsyncMock()
105+
mock_config = Config(project_root=str(project_root))
72106
with (
73107
patch("vectorcode.mcp_main.get_database_connector", return_value=mock_db),
74108
patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
75-
patch("os.path.isfile", side_effect=lambda x: x == "file1.py"),
109+
patch(
110+
"vectorcode.mcp_main.vectorise_worker", new_callable=AsyncMock
111+
) as mock_worker,
76112
):
77-
await vectorise_files(paths=["file1.py"], project_root=str(tmp_path))
78-
mock_db.vectorise.assert_called_with(file_path="file1.py")
113+
await vectorise_files(
114+
paths=[str(project_root / "file1.py"), str(project_root / "ignored.py")],
115+
project_root=str(project_root),
116+
)
117+
mock_worker.assert_called_once_with(
118+
mock_db, str(project_root / "file1.py"), ANY, ANY, ANY
119+
)
120+
121+
122+
@pytest.mark.asyncio
123+
async def test_mcp_server(tmp_path):
124+
with (
125+
patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool,
126+
patch("vectorcode.mcp_main.find_project_config_dir", return_value=tmp_path),
127+
patch("vectorcode.mcp_main.get_project_config", return_value=Config()),
128+
):
129+
await mcp_server()
130+
assert mock_add_tool.call_count > 0
79131

80132

81133
@pytest.mark.asyncio
82-
async def test_mcp_server():
134+
async def test_mcp_server_ls_on_start(tmp_path):
83135
with (
84136
patch("mcp.server.fastmcp.FastMCP.add_tool") as mock_add_tool,
137+
patch("vectorcode.mcp_main.find_project_config_dir", return_value=tmp_path),
138+
patch("vectorcode.mcp_main.get_project_config", return_value=Config()),
139+
patch("vectorcode.mcp_main.list_collections", return_value=["path1", "path2"]),
85140
):
141+
mcp_config.ls_on_start = True
86142
await mcp_server()
87143
assert mock_add_tool.call_count > 0
144+
mcp_config.ls_on_start = False
88145

89146

90147
@pytest.mark.asyncio
@@ -108,15 +165,123 @@ async def test_ls_files_success(tmp_path):
108165

109166
@pytest.mark.asyncio
110167
async def test_rm_files_success(tmp_path):
168+
(tmp_path / "file1.py").touch()
111169
with (
112170
patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
113171
patch("vectorcode.mcp_main.get_project_config") as mock_get_config,
114-
patch("os.path.isfile", side_effect=lambda x: x == "file1.py"),
115172
):
116173
mock_db = AsyncMock()
117174
mock_get_db.return_value = mock_db
118175
mock_get_config.return_value = Config(project_root=str(tmp_path))
119176

120-
await rm_files(files=["file1.py"], project_root=str(tmp_path))
177+
await rm_files(files=[str(tmp_path / "file1.py")], project_root=str(tmp_path))
121178

122179
mock_db.delete.assert_called_once()
180+
181+
182+
@pytest.mark.asyncio
183+
async def test_rm_files_no_files(tmp_path):
184+
with (
185+
patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
186+
patch("vectorcode.mcp_main.get_project_config") as mock_get_config,
187+
):
188+
mock_db = AsyncMock()
189+
mock_get_db.return_value = mock_db
190+
mock_get_config.return_value = Config(project_root=str(tmp_path))
191+
192+
await rm_files(files=["file1.py"], project_root=str(tmp_path))
193+
194+
mock_db.delete.assert_not_called()
195+
196+
197+
def test_get_arg_parser():
198+
parser = get_arg_parser()
199+
args = parser.parse_args(["-n", "5", "--ls-on-start"])
200+
assert args.number == 5
201+
assert args.ls_on_start is True
202+
203+
204+
def test_parse_cli_args():
205+
with patch.object(sys, "argv", ["", "-n", "5", "--ls-on-start"]):
206+
config = parse_cli_args()
207+
assert config.n_results == 5
208+
assert config.ls_on_start is True
209+
210+
211+
@pytest.mark.asyncio
212+
async def test_vectorise_files_exception(tmp_path):
213+
mock_db = AsyncMock()
214+
mock_config = Config(project_root=str(tmp_path))
215+
(tmp_path / "file1.py").touch()
216+
with (
217+
patch("vectorcode.mcp_main.get_database_connector", return_value=mock_db),
218+
patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
219+
patch(
220+
"vectorcode.mcp_main.vectorise_worker", side_effect=Exception("test error")
221+
),
222+
):
223+
with pytest.raises(McpError):
224+
await vectorise_files(
225+
paths=[str(tmp_path / "file1.py")], project_root=str(tmp_path)
226+
)
227+
228+
229+
@pytest.mark.asyncio
230+
async def test_query_tool_exception(tmp_path):
231+
mock_config = Config(project_root=tmp_path)
232+
with (
233+
patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
234+
patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
235+
patch(
236+
"vectorcode.mcp_main.get_reranked_results",
237+
side_effect=Exception("test error"),
238+
),
239+
):
240+
mock_db = AsyncMock()
241+
mock_get_db.return_value = mock_db
242+
mock_db._configs = mock_config
243+
244+
with pytest.raises(McpError):
245+
await query_tool(
246+
n_query=2, query_messages=["keyword1"], project_root=str(tmp_path)
247+
)
248+
249+
250+
@pytest.mark.asyncio
251+
async def test_vectorise_files_mcp_exception(tmp_path):
252+
mock_db = AsyncMock()
253+
mock_config = Config(project_root=str(tmp_path))
254+
(tmp_path / "file1.py").touch()
255+
with (
256+
patch("vectorcode.mcp_main.get_database_connector", return_value=mock_db),
257+
patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
258+
patch(
259+
"vectorcode.mcp_main.vectorise_worker",
260+
side_effect=McpError(ErrorData(code=1, message="test error")),
261+
),
262+
):
263+
with pytest.raises(McpError):
264+
await vectorise_files(
265+
paths=[str(tmp_path / "file1.py")], project_root=str(tmp_path)
266+
)
267+
268+
269+
@pytest.mark.asyncio
270+
async def test_query_tool_mcp_exception(tmp_path):
271+
mock_config = Config(project_root=tmp_path)
272+
with (
273+
patch("vectorcode.mcp_main.get_database_connector") as mock_get_db,
274+
patch("vectorcode.mcp_main.get_project_config", return_value=mock_config),
275+
patch(
276+
"vectorcode.mcp_main.get_reranked_results",
277+
side_effect=McpError(ErrorData(code=1, message="test error")),
278+
),
279+
):
280+
mock_db = AsyncMock()
281+
mock_get_db.return_value = mock_db
282+
mock_db._configs = mock_config
283+
284+
with pytest.raises(McpError):
285+
await query_tool(
286+
n_query=2, query_messages=["keyword1"], project_root=str(tmp_path)
287+
)

0 commit comments

Comments
 (0)