-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtest_tool_toggle.py
More file actions
227 lines (178 loc) · 9.05 KB
/
test_tool_toggle.py
File metadata and controls
227 lines (178 loc) · 9.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""Tests for the tool toggling feature (--disabled-tools + TOML config)."""
import asyncio
import pytest
import mcp_codebase_index.server as srv
@pytest.fixture(autouse=True)
def _reset_toggle_state(monkeypatch):
    """Isolate the tool-toggle globals of ``srv`` for every test.

    ``srv._disabled_tools`` starts each test as an empty set. ``monkeypatch``
    puts the module's original object back at teardown, even if the test
    rebinds ``srv._disabled_tools`` in the meantime.

    ``srv._tool_call_counts`` is snapshotted and restored because several
    tests clear it outright or bump it through ``srv.call_tool``. Without the
    restore those mutations would leak into whatever test runs next.
    """
    monkeypatch.setattr(srv, "_disabled_tools", set())
    # Copy (not alias) so in-place mutations during the test don't touch it.
    saved_counts = dict(srv._tool_call_counts)
    yield
    # Restore in place rather than rebinding, so the object the server
    # module holds keeps its identity and container type.
    srv._tool_call_counts.clear()
    srv._tool_call_counts.update(saved_counts)
# ---------------------------------------------------------------------------
# _load_disabled_tools_from_config
# ---------------------------------------------------------------------------
class TestLoadDisabledToolsFromConfig:
    """Reading ``disabled_tools`` out of ``.mcp-codebase-index.toml``.

    Every malformed or missing input is expected to degrade to an empty set
    rather than raise.
    """

    @staticmethod
    def _load(root, toml_text=None):
        """Optionally write *toml_text* as the project config, then load it.

        When *toml_text* is None no config file is created, which exercises
        the "file absent" path.
        """
        if toml_text is not None:
            (root / ".mcp-codebase-index.toml").write_text(toml_text)
        return srv._load_disabled_tools_from_config(str(root))

    def test_no_config_file(self, tmp_path):
        assert self._load(tmp_path) == set()

    def test_valid_config(self, tmp_path):
        loaded = self._load(
            tmp_path, 'disabled_tools = ["search_codebase", "get_call_chain"]\n'
        )
        assert loaded == {"search_codebase", "get_call_chain"}

    def test_empty_list(self, tmp_path):
        assert self._load(tmp_path, "disabled_tools = []\n") == set()

    def test_invalid_type_not_list(self, tmp_path):
        # A bare string instead of an array yields nothing.
        assert self._load(tmp_path, 'disabled_tools = "search_codebase"\n') == set()

    def test_invalid_type_list_of_non_strings(self, tmp_path):
        # An array of integers yields nothing.
        assert self._load(tmp_path, "disabled_tools = [1, 2]\n") == set()

    def test_malformed_toml(self, tmp_path):
        # Unparseable TOML yields nothing (no exception escapes).
        assert self._load(tmp_path, "not valid toml [[[") == set()

    def test_missing_key(self, tmp_path):
        # Valid TOML that lacks a top-level ``disabled_tools`` key.
        assert self._load(tmp_path, "[other]\nfoo = 1\n") == set()

    def test_whitespace_in_values_stripped(self, tmp_path):
        loaded = self._load(
            tmp_path, 'disabled_tools = [" search_codebase ", "get_call_chain "]\n'
        )
        assert loaded == {"search_codebase", "get_call_chain"}

    def test_blank_strings_filtered(self, tmp_path):
        # Entries that are empty after stripping are dropped.
        loaded = self._load(
            tmp_path, 'disabled_tools = ["search_codebase", " ", ""]\n'
        )
        assert loaded == {"search_codebase"}
# ---------------------------------------------------------------------------
# _init_disabled_tools
# ---------------------------------------------------------------------------
class TestInitDisabledTools:
    """Combining CLI and config-file sources into ``srv._disabled_tools``."""

    @staticmethod
    def _write_config(root, toml_text):
        """Write *toml_text* as the project's ``.mcp-codebase-index.toml``."""
        (root / ".mcp-codebase-index.toml").write_text(toml_text)

    def test_cli_only(self, tmp_path):
        srv._init_disabled_tools(["search_codebase"], project_root=str(tmp_path))
        assert srv._disabled_tools == {"search_codebase"}

    def test_config_only(self, tmp_path):
        self._write_config(tmp_path, 'disabled_tools = ["get_call_chain"]\n')
        srv._init_disabled_tools(None, project_root=str(tmp_path))
        assert srv._disabled_tools == {"get_call_chain"}

    def test_union_of_cli_and_config(self, tmp_path):
        # Both sources contribute; neither overrides the other.
        self._write_config(tmp_path, 'disabled_tools = ["get_call_chain"]\n')
        srv._init_disabled_tools(["search_codebase"], project_root=str(tmp_path))
        assert srv._disabled_tools == {"search_codebase", "get_call_chain"}

    def test_protected_tools_cannot_be_disabled(self, tmp_path):
        requested = ["reindex", "get_usage_stats", "search_codebase"]
        srv._init_disabled_tools(requested, project_root=str(tmp_path))
        for protected in ("reindex", "get_usage_stats"):
            assert protected not in srv._disabled_tools
        assert "search_codebase" in srv._disabled_tools

    def test_unknown_tools_ignored(self, tmp_path):
        requested = ["not_a_real_tool", "search_codebase"]
        srv._init_disabled_tools(requested, project_root=str(tmp_path))
        assert "not_a_real_tool" not in srv._disabled_tools
        assert "search_codebase" in srv._disabled_tools

    def test_empty_cli_list(self, tmp_path):
        srv._init_disabled_tools([], project_root=str(tmp_path))
        assert srv._disabled_tools == set()

    def test_none_cli_no_config(self, tmp_path):
        srv._init_disabled_tools(None, project_root=str(tmp_path))
        assert srv._disabled_tools == set()
# ---------------------------------------------------------------------------
# list_tools filtering
# ---------------------------------------------------------------------------
class TestListToolsFiltering:
    """``srv.list_tools`` omits disabled tools but keeps protected ones."""

    @staticmethod
    def _listed():
        """Run the async ``srv.list_tools`` to completion and return its result."""
        return asyncio.run(srv.list_tools())

    def test_no_disabled_returns_all(self):
        srv._disabled_tools = set()
        assert len(self._listed()) == len(srv.TOOLS)

    def test_disabled_tools_excluded(self):
        srv._disabled_tools = {"search_codebase", "get_call_chain"}
        tools = self._listed()
        names = {tool.name for tool in tools}
        for hidden in ("search_codebase", "get_call_chain"):
            assert hidden not in names
        assert len(tools) == len(srv.TOOLS) - 2

    def test_protected_always_present(self):
        srv._disabled_tools = {"search_codebase"}
        names = {tool.name for tool in self._listed()}
        for protected in ("reindex", "get_usage_stats"):
            assert protected in names
# ---------------------------------------------------------------------------
# call_tool guard
# ---------------------------------------------------------------------------
class TestCallToolGuard:
    """``srv.call_tool`` short-circuits calls to disabled tools."""

    @staticmethod
    def _call(name, arguments):
        """Run the async ``srv.call_tool`` to completion and return its result."""
        return asyncio.run(srv.call_tool(name, arguments))

    def test_disabled_tool_returns_error(self):
        srv._disabled_tools = {"search_codebase"}
        result = self._call("search_codebase", {"pattern": "foo"})
        assert len(result) == 1
        message = result[0].text
        assert "disabled" in message
        assert "search_codebase" in message

    def test_disabled_tool_not_counted(self):
        # A refused call must not show up in the usage counters.
        srv._tool_call_counts.clear()
        srv._disabled_tools = {"search_codebase"}
        self._call("search_codebase", {"pattern": "foo"})
        assert "search_codebase" not in srv._tool_call_counts

    def test_enabled_tool_not_blocked(self, monkeypatch):
        """An enabled tool should proceed past the guard (we mock _ensure_index)."""
        srv._disabled_tools = set()
        srv._tool_call_counts.clear()
        # Stub the indexing hooks so no real indexing runs.
        for hook in ("_ensure_index", "_maybe_incremental_update"):
            monkeypatch.setattr(srv, hook, lambda: None)
        fake_queries = {"get_project_summary": lambda: "mock summary"}
        monkeypatch.setattr(srv, "_query_fns", fake_queries)

        result = self._call("get_project_summary", {})

        assert result[0].text == "mock summary"
        assert srv._tool_call_counts.get("get_project_summary") == 1
# ---------------------------------------------------------------------------
# CLI argument parsing (main_sync integration)
# ---------------------------------------------------------------------------
class TestCliParsing:
    """Parsing of the ``--disabled-tools`` command-line flag.

    NOTE(review): ``_parse`` rebuilds the parser instead of calling
    ``main_sync``, so these tests only track the real CLI while the two
    definitions are kept in step by hand.
    """

    @staticmethod
    def _split_tool_list(raw):
        """Split a comma-separated flag value into stripped, non-empty names."""
        return [piece.strip() for piece in raw.split(",") if piece.strip()]

    def _parse(self, argv: list[str]) -> tuple:
        """Parse *argv* with a replica of main_sync's argparse setup.

        Returns the ``(namespace, unknown_args)`` pair from
        ``parse_known_args``, so unrecognised flags are collected rather
        than causing ``SystemExit``.
        """
        import argparse

        cli = argparse.ArgumentParser(description="MCP codebase index server")
        cli.add_argument(
            "--disabled-tools",
            type=self._split_tool_list,
            default=None,
            help="Comma-separated list of tool names to disable",
        )
        return cli.parse_known_args(argv)

    def test_disabled_tools_parsed(self):
        args, extras = self._parse(["--disabled-tools", "search_codebase,get_call_chain"])
        assert args.disabled_tools == ["search_codebase", "get_call_chain"]
        assert extras == []

    def test_disabled_tools_with_spaces(self):
        args, _ = self._parse(["--disabled-tools", " search_codebase , get_call_chain "])
        assert args.disabled_tools == ["search_codebase", "get_call_chain"]

    def test_no_disabled_tools_flag(self):
        args, extras = self._parse([])
        assert args.disabled_tools is None
        assert extras == []

    def test_unknown_args_not_fatal(self):
        """parse_known_args must not raise SystemExit on unknown flags."""
        args, extras = self._parse(["--some-future-flag", "value"])
        assert args.disabled_tools is None
        assert "--some-future-flag" in extras

    def test_unknown_args_coexist_with_disabled_tools(self):
        argv = ["--disabled-tools", "search_codebase", "--unknown-flag"]
        args, extras = self._parse(argv)
        assert args.disabled_tools == ["search_codebase"]
        assert "--unknown-flag" in extras