Skip to content

Commit 04207c4

Browse files
committed
Add tool enable/disable functionality with client notifications
1 parent b8f7b02 commit 04207c4

File tree

3 files changed

+82
-3
lines changed

3 files changed

+82
-3
lines changed

src/mcp/server/fastmcp/tools/base.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from mcp.server.fastmcp.exceptions import ToolError
1010
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
11-
from mcp.types import ToolAnnotations
11+
from mcp.types import ServerNotification, ToolAnnotations, ToolListChangedNotification
1212

1313
if TYPE_CHECKING:
1414
from mcp.server.fastmcp.server import Context
@@ -34,6 +34,7 @@ class Tool(BaseModel):
3434
annotations: ToolAnnotations | None = Field(
3535
None, description="Optional annotations for the tool"
3636
)
37+
enabled: bool = Field(default=True, description="Whether the tool is enabled")
3738

3839
@classmethod
3940
def from_function(
@@ -98,3 +99,29 @@ async def run(
9899
)
99100
except Exception as e:
100101
raise ToolError(f"Error executing tool {self.name}: {e}") from e
102+
103+
async def enable(
104+
self, context: Context[ServerSessionT, LifespanContextT] | None = None
105+
) -> None:
106+
"""Enable the tool and notify clients."""
107+
if not self.enabled:
108+
self.enabled = True
109+
if context and context.session:
110+
notification = ToolListChangedNotification(
111+
method="notifications/tools/list_changed"
112+
)
113+
server_notification = ServerNotification.model_validate(notification)
114+
await context.session.send_notification(server_notification)
115+
116+
async def disable(
117+
self, context: Context[ServerSessionT, LifespanContextT] | None = None
118+
) -> None:
119+
"""Disable the tool and notify clients."""
120+
if self.enabled:
121+
self.enabled = False
122+
if context and context.session:
123+
notification = ToolListChangedNotification(
124+
method="notifications/tools/list_changed"
125+
)
126+
server_notification = ServerNotification.model_validate(notification)
127+
await context.session.send_notification(server_notification)

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def get_tool(self, name: str) -> Tool | None:
2828
return self._tools.get(name)
2929

3030
def list_tools(self) -> list[Tool]:
31-
"""List all registered tools."""
32-
return list(self._tools.values())
31+
"""List all enabled registered tools."""
32+
return [tool for tool in self._tools.values() if tool.enabled]
3333

3434
def add_tool(
3535
self,
@@ -61,4 +61,7 @@ async def call_tool(
6161
if not tool:
6262
raise ToolError(f"Unknown tool: {name}")
6363

64+
if not tool.enabled:
65+
raise ToolError(f"Tool is disabled: {name}")
66+
6467
return await tool.run(arguments, context=context)

tests/server/fastmcp/test_tool_manager.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,52 @@ def echo(message: str) -> str:
362362
assert tools[0].annotations is not None
363363
assert tools[0].annotations.title == "Echo Tool"
364364
assert tools[0].annotations.readOnlyHint is True
365+
366+
367+
class TestToolEnableDisable:
368+
"""Test enabling and disabling tools."""
369+
370+
@pytest.mark.anyio
371+
async def test_enable_disable_tool(self):
372+
"""Test enabling and disabling a tool."""
373+
374+
def add(a: int, b: int) -> int:
375+
"""Add two numbers."""
376+
return a + b
377+
378+
manager = ToolManager()
379+
tool = manager.add_tool(add)
380+
381+
# Tool should be enabled by default
382+
assert tool.enabled is True
383+
384+
# Disable the tool
385+
await tool.disable()
386+
assert tool.enabled is False
387+
388+
# Enable the tool
389+
await tool.enable()
390+
assert tool.enabled is True
391+
392+
@pytest.mark.anyio
393+
async def test_enable_disable_no_change(self):
394+
"""Test enabling and disabling a tool when there's no state change."""
395+
396+
def add(a: int, b: int) -> int:
397+
"""Add two numbers."""
398+
return a + b
399+
400+
manager = ToolManager()
401+
tool = manager.add_tool(add)
402+
403+
# Enable an already enabled tool (should not change state)
404+
await tool.enable()
405+
assert tool.enabled is True
406+
407+
# Disable the tool
408+
await tool.disable()
409+
assert tool.enabled is False
410+
411+
# Disable an already disabled tool (should not change state)
412+
await tool.disable()
413+
assert tool.enabled is False

0 commit comments

Comments
 (0)