diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index bcf80d62a..4847a74a3 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -34,12 +34,6 @@ async def __call__( ) -> types.ElicitResult | types.ErrorData: ... -class ListRootsFnT(Protocol): - async def __call__( - self, context: RequestContext["ClientSession", Any] - ) -> types.ListRootsResult | types.ErrorData: ... - - class LoggingFnT(Protocol): async def __call__( self, @@ -80,15 +74,6 @@ async def _default_elicitation_callback( ) -async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], -) -> types.ListRootsResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="List roots not supported", - ) - - async def _default_logging_callback( params: types.LoggingMessageNotificationParams, ) -> None: @@ -114,7 +99,6 @@ def __init__( read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, - list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, @@ -129,7 +113,6 @@ def __init__( self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._elicitation_callback = elicitation_callback or _default_elicitation_callback - self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} @@ -139,15 +122,6 @@ async def initialize(self) -> types.InitializeResult: elicitation = ( types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None ) - roots = ( - # TODO: Should this be based on whether we - # _will_ send notifications, or only whether - # they're supported? - types.RootsCapability(listChanged=True) - if self._list_roots_callback is not _default_list_roots_callback - else None - ) - result = await self.send_request( types.ClientRequest( types.InitializeRequest( @@ -157,7 +131,6 @@ async def initialize(self) -> types.InitializeResult: sampling=sampling, elicitation=elicitation, experimental=None, - roots=roots, ), clientInfo=self._client_info, ), @@ -381,9 +354,23 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: return result - async def send_roots_list_changed(self) -> None: - """Send a roots/list_changed notification.""" - await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) + async def set_roots(self, roots: list[types.Root]) -> types.EmptyResult: + """Send a roots/set request to set the server's root directories.""" + return await self.send_request( + types.ClientRequest( + types.SetRootsRequest( + params=types.SetRootsRequestParams(roots=roots), + ) + ), + types.EmptyResult, + ) + + async def list_roots(self) -> types.ListRootsResult: + """Send a roots/list request to query the server's current roots.""" + return await self.send_request( + types.ClientRequest(types.ListRootsRequest()), + types.ListRootsResult, + ) async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession, Any]( @@ -406,12 +393,6 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.ListRootsRequest(): - with responder: - response = await self._list_roots_callback(ctx) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) - case types.PingRequest(): with responder: return await responder.respond(types.ClientResult(root=types.EmptyResult())) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index d86fa85e3..febb8794f 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -21,6 +21,7 @@ from starlette.routing import Mount, Route from starlette.types import Receive, Scope, Send +import mcp.types as types from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier @@ -177,6 +178,7 @@ def __init__( self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) + self._roots: list[types.Root] = [] # Validate auth configuration if self.settings.auth is not None: if auth_server_provider and token_verifier: @@ -265,6 +267,8 @@ def _setup_handlers(self) -> None: self._mcp_server.list_prompts()(self.list_prompts) self._mcp_server.get_prompt()(self.get_prompt) self._mcp_server.list_resource_templates()(self.list_resource_templates) + self._mcp_server.set_roots()(self.set_roots) + self._mcp_server.list_roots()(self.list_roots) async def list_tools(self) -> list[MCPTool]: """List all available tools.""" @@ -340,6 +344,19 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent logger.exception(f"Error reading resource {uri}") raise ResourceError(str(e)) + async def set_roots(self, roots: list[types.Root]) -> None: + """Handle set_roots request from the client.""" + self._roots = roots + + async def list_roots(self) -> list[types.Root]: + """Handle list_roots request from the client.""" + return self._roots + + @property + def roots(self) -> list[types.Root]: + """Get the currently set roots.""" + return self._roots + def add_tool( self, fn: AnyFunction, diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 3076e283e..c17cca41c 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -212,6 +212,11 @@ def get_capabilities( if types.CompleteRequest in self.request_handlers: completions_capability = types.CompletionsCapability() + # Set roots capability if handler exists + roots_capability = None + if types.SetRootsRequest in self.request_handlers: + roots_capability = types.RootsCapability() + return types.ServerCapabilities( prompts=prompts_capability, resources=resources_capability, @@ -219,6 +224,7 @@ def get_capabilities( logging=logging_capability, experimental=experimental_capabilities, completions=completions_capability, + roots=roots_capability, ) @property @@ -353,6 +359,36 @@ async def handler(req: types.SetLevelRequest): return decorator + def set_roots(self): + """Register a handler for SetRootsRequest.""" + + def decorator(func: Callable[[list[types.Root]], Awaitable[None]]): + logger.debug("Registering handler for SetRootsRequest") + + async def handler(req: types.SetRootsRequest): + await func(req.params.roots) + return types.ServerResult(types.EmptyResult()) + + self.request_handlers[types.SetRootsRequest] = handler + return func + + return decorator + + def list_roots(self): + """Register a handler for ListRootsRequest.""" + + def decorator(func: Callable[[], Awaitable[list[types.Root]]]): + logger.debug("Registering handler for ListRootsRequest") + + async def handler(_: Any): + roots = await func() + return types.ServerResult(types.ListRootsResult(roots=roots)) + + self.request_handlers[types.ListRootsRequest] = handler + return func + + return decorator + def subscribe_resource(self): def decorator(func: Callable[[AnyUrl], Awaitable[None]]): logger.debug("Registering handler for SubscribeRequest") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 7b3680f7c..c0f3e4218 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -111,12 +111,6 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: client_caps = self._client_params.capabilities # Check each specified capability in the passed in capability object - if capability.roots is not None: - if client_caps.roots is None: - return False - if capability.roots.listChanged and not client_caps.roots.listChanged: - return False - if capability.sampling is not None: if client_caps.sampling is None: return False @@ -244,13 +238,6 @@ async def create_message( ), ) - async def list_roots(self) -> types.ListRootsResult: - """Send a roots/list request.""" - return await self.send_request( - types.ServerRequest(types.ListRootsRequest()), - types.ListRootsResult, - ) - async def elicit( self, message: str, diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index c94e5e6ac..57d78805e 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -14,7 +14,6 @@ from mcp.client.session import ( ClientSession, ElicitationFnT, - ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT, @@ -55,7 +54,6 @@ async def create_connected_server_and_client_session( server: Server[Any], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, - list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, @@ -87,7 +85,6 @@ async def create_connected_server_and_client_session( write_stream=client_write, read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, - list_roots_callback=list_roots_callback, logging_callback=logging_callback, message_handler=message_handler, client_info=client_info, diff --git a/src/mcp/types.py b/src/mcp/types.py index 62feda87a..0c905750b 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -223,8 +223,6 @@ class Implementation(BaseMetadata): class RootsCapability(BaseModel): """Capability for root operations.""" - listChanged: bool | None = None - """Whether the client supports notifications for changes to the roots list.""" model_config = ConfigDict(extra="allow") @@ -249,8 +247,6 @@ class ClientCapabilities(BaseModel): """Present if the client supports sampling from an LLM.""" elicitation: ElicitationCapability | None = None """Present if the client supports elicitation from the user.""" - roots: RootsCapability | None = None - """Present if the client supports listing roots.""" model_config = ConfigDict(extra="allow") @@ -307,6 +303,8 @@ class ServerCapabilities(BaseModel): """Present if the server offers any tools to call.""" completions: CompletionsCapability | None = None """Present if the server offers autocompletion suggestions for prompts and resources.""" + roots: RootsCapability | None = None + """Present if the server supports accepting roots from the client.""" model_config = ConfigDict(extra="allow") @@ -1133,21 +1131,6 @@ class CompleteResult(Result): completion: Completion -class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): - """ - Sent from the server to request a list of root URIs from the client. Roots allow - servers to ask for specific directories or files to operate on. A common example - for roots is providing a set of repositories or directories a server should operate - on. - - This request is typically used when the server needs to understand the file system - structure or access specific locations that the client has permission to read from. - """ - - method: Literal["roots/list"] = "roots/list" - params: RequestParams | None = None - - class Root(BaseModel): """Represents a root directory or file that the server can operate on.""" @@ -1171,30 +1154,43 @@ class Root(BaseModel): model_config = ConfigDict(extra="allow") -class ListRootsResult(Result): +class SetRootsRequestParams(RequestParams): + """Parameters for setting the root directories or files.""" + + roots: list[Root] + """The new list of roots to set. This replaces any existing roots.""" + model_config = ConfigDict(extra="allow") + + +class SetRootsRequest(Request[SetRootsRequestParams, Literal["roots/set"]]): """ - The client's response to a roots/list request from the server. - This result contains an array of Root objects, each representing a root directory - or file that the server can operate on. + Sent from the client to the server to set the root directories or files that the server can operate on. + This replaces the entire list of roots with the new set provided. + The server responds with EmptyResult to acknowledge the roots have been set. """ - roots: list[Root] + method: Literal["roots/set"] = "roots/set" + params: SetRootsRequestParams -class RootsListChangedNotification( - Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]] -): +class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): + """ + Sent from the client to query the list of root URIs currently set on the server. + The server responds with the list of roots that were previously set via SetRootsRequest. """ - A notification from the client to the server, informing it that the list of - roots has changed. - This notification should be sent whenever the client adds, removes, or - modifies any root. The server should then request an updated list of roots - using the ListRootsRequest. + method: Literal["roots/list"] = "roots/list" + params: RequestParams | None = None + + +class ListRootsResult(Result): + """ + The server's response to a roots/list request from the client. + This result contains an array of Root objects representing the root directories + or files currently configured on the server. """ - method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed" - params: NotificationParams | None = None + roots: list[Root] class CancelledNotificationParams(NotificationParams): @@ -1232,14 +1228,14 @@ class ClientRequest( | UnsubscribeRequest | CallToolRequest | ListToolsRequest + | SetRootsRequest + | ListRootsRequest ] ): pass -class ClientNotification( - RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] -): +class ClientNotification(RootModel[CancelledNotification | ProgressNotification | InitializedNotification]): pass @@ -1281,11 +1277,11 @@ class ElicitResult(Result): """ -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): +class ClientResult(RootModel[EmptyResult | CreateMessageResult | ElicitResult]): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]): +class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ElicitRequest]): pass @@ -1315,6 +1311,7 @@ class ServerResult( | ReadResourceResult | CallToolResult | ListToolsResult + | ListRootsResult ] ): pass diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py deleted file mode 100644 index 0da0fff07..000000000 --- a/tests/client/test_list_roots_callback.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest -from pydantic import FileUrl - -from mcp.client.session import ClientSession -from mcp.server.fastmcp.server import Context -from mcp.server.session import ServerSession -from mcp.shared.context import RequestContext -from mcp.shared.memory import ( - create_connected_server_and_client_session as create_session, -) -from mcp.types import ListRootsResult, Root, TextContent - - -@pytest.mark.anyio -async def test_list_roots_callback(): - from mcp.server.fastmcp import FastMCP - - server = FastMCP("test") - - callback_return = ListRootsResult( - roots=[ - Root( - uri=FileUrl("file://users/fake/test"), - name="Test Root 1", - ), - Root( - uri=FileUrl("file://users/fake/test/2"), - name="Test Root 2", - ), - ] - ) - - async def list_roots_callback( - context: RequestContext[ClientSession, None], - ) -> ListRootsResult: - return callback_return - - @server.tool("test_list_roots") - async def test_list_roots(context: Context[ServerSession, None], message: str): - roots = await context.session.list_roots() - assert roots == callback_return - return True - - # Test with list_roots callback - async with create_session(server._mcp_server, list_roots_callback=list_roots_callback) as client_session: - # Make a request to trigger sampling callback - result = await client_session.call_tool("test_list_roots", {"message": "test message"}) - assert result.isError is False - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "true" - - # Test without list_roots callback - async with create_session(server._mcp_server) as client_session: - # Make a request to trigger sampling callback - result = await client_session.call_tool("test_list_roots", {"message": "test message"}) - assert result.isError is True - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported" diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 88e64711b..0b7963943 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -21,7 +21,7 @@ from mcp import ClientSession, types from mcp.client.streamable_http import streamablehttp_client from mcp.shared.session import RequestResponder -from mcp.types import ClientNotification, RootsListChangedNotification +from mcp.types import ClientNotification, InitializedNotification def create_non_sdk_server_app() -> Starlette: @@ -142,8 +142,9 @@ async def message_handler( await session.initialize() # The test server returns a 204 instead of the expected 202 + # Send a duplicate initialized notification to test notification handling await session.send_notification( - ClientNotification(RootsListChangedNotification(method="notifications/roots/list_changed")) + ClientNotification(InitializedNotification(method="notifications/initialized")) ) if returned_exception: diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 53b60fce6..6abc82387 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -413,7 +413,6 @@ async def mock_server(): # Assert that capabilities are properly set with defaults assert received_capabilities is not None assert received_capabilities.sampling is None # No custom sampling callback - assert received_capabilities.roots is None # No custom list_roots callback @pytest.mark.anyio @@ -434,11 +433,6 @@ async def custom_sampling_callback( model="test-model", ) - async def custom_list_roots_callback( - context: RequestContext["ClientSession", Any], - ) -> types.ListRootsResult | types.ErrorData: - return types.ListRootsResult(roots=[]) - async def mock_server(): nonlocal received_capabilities @@ -479,7 +473,6 @@ async def mock_server(): server_to_client_receive, client_to_server_send, sampling_callback=custom_sampling_callback, - list_roots_callback=custom_list_roots_callback, ) as session, anyio.create_task_group() as tg, client_to_server_send, @@ -494,6 +487,3 @@ async def mock_server(): assert received_capabilities is not None assert received_capabilities.sampling is not None # Custom sampling callback provided assert isinstance(received_capabilities.sampling, types.SamplingCapability) - assert received_capabilities.roots is not None # Custom list_roots callback provided - assert isinstance(received_capabilities.roots, types.RootsCapability) - assert received_capabilities.roots.listChanged is True # Should be True for custom callback diff --git a/tests/test_roots.py b/tests/test_roots.py new file mode 100644 index 000000000..5b81cf545 --- /dev/null +++ b/tests/test_roots.py @@ -0,0 +1,106 @@ +import pytest +from pydantic import FileUrl + +import mcp.types as types +from mcp.server.fastmcp import FastMCP +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) + + +@pytest.mark.anyio +async def test_set_roots(): + server = FastMCP("test") + + @server.tool("check_roots") + async def check_roots(): + return f"Server has {len(server.roots)} roots" + + async with create_session(server._mcp_server) as client_session: + roots_to_set = [ + types.Root( + uri=FileUrl("file:///users/fake/test"), + name="Test Root 1", + ), + types.Root( + uri=FileUrl("file:///users/fake/test/2"), + name="Test Root 2", + ), + ] + + result = await client_session.set_roots(roots_to_set) + assert isinstance(result, types.EmptyResult) + + tool_result = await client_session.call_tool("check_roots", {}) + assert tool_result.isError is False + assert len(tool_result.content) > 0 + content = tool_result.content[0] + assert isinstance(content, types.TextContent) + assert content.text == "Server has 2 roots" + + +@pytest.mark.anyio +async def test_list_roots(): + server = FastMCP("test") + + async with create_session(server._mcp_server) as client_session: + # Initially no roots + result = await client_session.list_roots() + assert result.roots == [] + + # Set some roots + roots_to_set = [ + types.Root(uri=FileUrl("file:///project/src"), name="Source"), + types.Root(uri=FileUrl("file:///project/tests"), name="Tests"), + ] + await client_session.set_roots(roots_to_set) + + # Query them back + result = await client_session.list_roots() + assert len(result.roots) == 2 + assert result.roots[0].name == "Source" + assert result.roots[1].name == "Tests" + + +@pytest.mark.anyio +async def test_roots_replacement(): + server = FastMCP("test") + + @server.tool("get_root_names") + async def get_root_names(): + if not server.roots: + return "No roots" + return ", ".join(r.name or str(r.uri) for r in server.roots) + + async with create_session(server._mcp_server) as client_session: + # Set initial roots + initial_roots = [ + types.Root(uri=FileUrl("file:///project/src"), name="Source Code"), + types.Root(uri=FileUrl("file:///project/tests"), name="Tests"), + types.Root(uri=FileUrl("file:///project/docs"), name="Documentation"), + ] + await client_session.set_roots(initial_roots) + + result = await client_session.call_tool("get_root_names", {}) + content = result.content[0] + assert isinstance(content, types.TextContent) + assert "Source Code" in content.text + assert "Tests" in content.text + assert "Documentation" in content.text + + # Replace with new roots + new_roots = [types.Root(uri=FileUrl("file:///new/location"), name="New Location")] + await client_session.set_roots(new_roots) + + result = await client_session.call_tool("get_root_names", {}) + content = result.content[0] + assert isinstance(content, types.TextContent) + assert content.text == "New Location" + + # Clear roots + await client_session.set_roots([]) + + result = await client_session.call_tool("get_root_names", {}) + content = result.content[0] + assert isinstance(content, types.TextContent) + assert content.text == "No roots"