Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 17 additions & 36 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ async def __call__(
) -> types.ElicitResult | types.ErrorData: ...


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
) -> types.ListRootsResult | types.ErrorData: ...


class LoggingFnT(Protocol):
async def __call__(
self,
Expand Down Expand Up @@ -80,15 +74,6 @@ async def _default_elicitation_callback(
)


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)


async def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
) -> None:
Expand All @@ -114,7 +99,6 @@ def __init__(
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
Expand All @@ -129,7 +113,6 @@ def __init__(
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
Expand All @@ -139,15 +122,6 @@ async def initialize(self) -> types.InitializeResult:
elicitation = (
types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None
)
roots = (
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
types.RootsCapability(listChanged=True)
if self._list_roots_callback is not _default_list_roots_callback
else None
)

result = await self.send_request(
types.ClientRequest(
types.InitializeRequest(
Expand All @@ -157,7 +131,6 @@ async def initialize(self) -> types.InitializeResult:
sampling=sampling,
elicitation=elicitation,
experimental=None,
roots=roots,
),
clientInfo=self._client_info,
),
Expand Down Expand Up @@ -381,9 +354,23 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult:

return result

async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
async def set_roots(self, roots: list[types.Root]) -> types.EmptyResult:
"""Send a roots/set request to set the server's root directories."""
return await self.send_request(
types.ClientRequest(
types.SetRootsRequest(
params=types.SetRootsRequestParams(roots=roots),
)
),
types.EmptyResult,
)

async def list_roots(self) -> types.ListRootsResult:
"""Send a roots/list request to query the server's current roots."""
return await self.send_request(
types.ClientRequest(types.ListRootsRequest()),
types.ListRootsResult,
)

async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
ctx = RequestContext[ClientSession, Any](
Expand All @@ -406,12 +393,6 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.PingRequest():
with responder:
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
Expand Down
17 changes: 17 additions & 0 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from starlette.routing import Mount, Route
from starlette.types import Receive, Scope, Send

import mcp.types as types
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier
Expand Down Expand Up @@ -177,6 +178,7 @@ def __init__(
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts)
self._roots: list[types.Root] = []
# Validate auth configuration
if self.settings.auth is not None:
if auth_server_provider and token_verifier:
Expand Down Expand Up @@ -265,6 +267,8 @@ def _setup_handlers(self) -> None:
self._mcp_server.list_prompts()(self.list_prompts)
self._mcp_server.get_prompt()(self.get_prompt)
self._mcp_server.list_resource_templates()(self.list_resource_templates)
self._mcp_server.set_roots()(self.set_roots)
self._mcp_server.list_roots()(self.list_roots)

async def list_tools(self) -> list[MCPTool]:
"""List all available tools."""
Expand Down Expand Up @@ -340,6 +344,19 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent
logger.exception(f"Error reading resource {uri}")
raise ResourceError(str(e))

async def set_roots(self, roots: list[types.Root]) -> None:
"""Handle set_roots request from the client."""
self._roots = roots

async def list_roots(self) -> list[types.Root]:
"""Handle list_roots request from the client."""
return self._roots

@property
def roots(self) -> list[types.Root]:
"""Get the currently set roots."""
return self._roots

def add_tool(
self,
fn: AnyFunction,
Expand Down
36 changes: 36 additions & 0 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,19 @@ def get_capabilities(
if types.CompleteRequest in self.request_handlers:
completions_capability = types.CompletionsCapability()

# Set roots capability if handler exists
roots_capability = None
if types.SetRootsRequest in self.request_handlers:
roots_capability = types.RootsCapability()

return types.ServerCapabilities(
prompts=prompts_capability,
resources=resources_capability,
tools=tools_capability,
logging=logging_capability,
experimental=experimental_capabilities,
completions=completions_capability,
roots=roots_capability,
)

@property
Expand Down Expand Up @@ -353,6 +359,36 @@ async def handler(req: types.SetLevelRequest):

return decorator

def set_roots(self):
"""Register a handler for SetRootsRequest."""

def decorator(func: Callable[[list[types.Root]], Awaitable[None]]):
logger.debug("Registering handler for SetRootsRequest")

async def handler(req: types.SetRootsRequest):
await func(req.params.roots)
return types.ServerResult(types.EmptyResult())

self.request_handlers[types.SetRootsRequest] = handler
return func

return decorator

def list_roots(self):
"""Register a handler for ListRootsRequest."""

def decorator(func: Callable[[], Awaitable[list[types.Root]]]):
logger.debug("Registering handler for ListRootsRequest")

async def handler(_: Any):
roots = await func()
return types.ServerResult(types.ListRootsResult(roots=roots))

self.request_handlers[types.ListRootsRequest] = handler
return func

return decorator

def subscribe_resource(self):
def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug("Registering handler for SubscribeRequest")
Expand Down
13 changes: 0 additions & 13 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,6 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
client_caps = self._client_params.capabilities

# Check each specified capability in the passed in capability object
if capability.roots is not None:
if client_caps.roots is None:
return False
if capability.roots.listChanged and not client_caps.roots.listChanged:
return False

if capability.sampling is not None:
if client_caps.sampling is None:
return False
Expand Down Expand Up @@ -244,13 +238,6 @@ async def create_message(
),
)

async def list_roots(self) -> types.ListRootsResult:
"""Send a roots/list request."""
return await self.send_request(
types.ServerRequest(types.ListRootsRequest()),
types.ListRootsResult,
)

async def elicit(
self,
message: str,
Expand Down
3 changes: 0 additions & 3 deletions src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from mcp.client.session import (
ClientSession,
ElicitationFnT,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
SamplingFnT,
Expand Down Expand Up @@ -55,7 +54,6 @@ async def create_connected_server_and_client_session(
server: Server[Any],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
Expand Down Expand Up @@ -87,7 +85,6 @@ async def create_connected_server_and_client_session(
write_stream=client_write,
read_timeout_seconds=read_timeout_seconds,
sampling_callback=sampling_callback,
list_roots_callback=list_roots_callback,
logging_callback=logging_callback,
message_handler=message_handler,
client_info=client_info,
Expand Down
Loading