Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 154 additions & 57 deletions src/fastmcp/server/middleware/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,132 @@ def unwrap(self) -> ToolResult:
)


class CachableTool(BaseModel):
"""A wrapper for Tool that can be cached, preserving the key field."""

name: str
key: str
title: str | None
description: str | None
parameters: dict[str, Any]
output_schema: dict[str, Any] | None
annotations: Any | None
meta: dict[str, Any] | None
tags: set[str]
enabled: bool

@classmethod
def wrap(cls, tool: Tool) -> Self:
return cls(
name=tool.name,
key=tool.key,
title=tool.title,
description=tool.description,
parameters=tool.parameters,
output_schema=tool.output_schema,
annotations=tool.annotations,
meta=tool.meta,
tags=tool.tags,
enabled=tool.enabled,
)

def unwrap(self) -> Tool:
return Tool(
name=self.name,
key=self.key,
title=self.title,
description=self.description,
parameters=self.parameters,
output_schema=self.output_schema,
annotations=self.annotations,
meta=self.meta,
tags=self.tags,
enabled=self.enabled,
)


class CachableResource(BaseModel):
"""A wrapper for Resource that can be cached, preserving the key field."""

name: str
key: str
title: str | None
description: str | None
uri: str
mime_type: str
annotations: Any | None
meta: dict[str, Any] | None
tags: set[str]
enabled: bool

@classmethod
def wrap(cls, resource: Resource) -> Self:
return cls(
name=resource.name,
key=resource.key,
title=resource.title,
description=resource.description,
uri=str(resource.uri),
mime_type=resource.mime_type,
annotations=resource.annotations,
meta=resource.meta,
tags=resource.tags,
enabled=resource.enabled,
)

def unwrap(self) -> Resource:
return Resource(
name=self.name,
key=self.key,
title=self.title,
description=self.description,
uri=self.uri,
mime_type=self.mime_type,
annotations=self.annotations,
meta=self.meta,
tags=self.tags,
enabled=self.enabled,
)


class CachablePrompt(BaseModel):
"""A wrapper for Prompt that can be cached, preserving the key field."""

name: str
key: str
title: str | None
description: str | None
arguments: list[Any] | None
meta: dict[str, Any] | None
tags: set[str]
enabled: bool

@classmethod
def wrap(cls, prompt: Prompt) -> Self:
return cls(
name=prompt.name,
key=prompt.key,
title=prompt.title,
description=prompt.description,
arguments=prompt.arguments,
meta=prompt.meta,
tags=prompt.tags,
enabled=prompt.enabled,
)

def unwrap(self) -> Prompt:
return Prompt(
name=self.name,
key=self.key,
title=self.title,
description=self.description,
arguments=self.arguments,
meta=self.meta,
tags=self.tags,
enabled=self.enabled,
)


class SharedMethodSettings(TypedDict):
"""Shared config for a cache method."""

Expand Down Expand Up @@ -182,22 +308,26 @@ def __init__(
call_tool_settings or CallToolSettings()
)

self._list_tools_cache: PydanticAdapter[list[Tool]] = PydanticAdapter(
self._list_tools_cache: PydanticAdapter[list[CachableTool]] = PydanticAdapter(
key_value=self._stats,
pydantic_model=list[Tool],
pydantic_model=list[CachableTool],
default_collection="tools/list",
)

self._list_resources_cache: PydanticAdapter[list[Resource]] = PydanticAdapter(
key_value=self._stats,
pydantic_model=list[Resource],
default_collection="resources/list",
self._list_resources_cache: PydanticAdapter[list[CachableResource]] = (
PydanticAdapter(
key_value=self._stats,
pydantic_model=list[CachableResource],
default_collection="resources/list",
)
)

self._list_prompts_cache: PydanticAdapter[list[Prompt]] = PydanticAdapter(
key_value=self._stats,
pydantic_model=list[Prompt],
default_collection="prompts/list",
self._list_prompts_cache: PydanticAdapter[list[CachablePrompt]] = (
PydanticAdapter(
key_value=self._stats,
pydantic_model=list[CachablePrompt],
default_collection="prompts/list",
)
)

self._read_resource_cache: PydanticAdapter[
Expand Down Expand Up @@ -234,33 +364,20 @@ async def on_list_tools(
return await call_next(context)

if cached_value := await self._list_tools_cache.get(key=GLOBAL_KEY):
return cached_value
return [item.unwrap() for item in cached_value]

tools: Sequence[Tool] = await call_next(context=context)

# Turn any subclass of Tool into a Tool
cachable_tools: list[Tool] = [
Tool(
name=tool.name,
title=tool.title,
description=tool.description,
parameters=tool.parameters,
output_schema=tool.output_schema,
annotations=tool.annotations,
meta=tool.meta,
tags=tool.tags,
enabled=tool.enabled,
)
for tool in tools
]
# Wrap tools in cacheable models
cachable_tools: list[CachableTool] = [CachableTool.wrap(tool) for tool in tools]

await self._list_tools_cache.put(
key=GLOBAL_KEY,
value=cachable_tools,
ttl=self._list_tools_settings.get("ttl", FIVE_MINUTES_IN_SECONDS),
)

return cachable_tools
return [item.unwrap() for item in cachable_tools]

@override
async def on_list_resources(
Expand All @@ -274,24 +391,13 @@ async def on_list_resources(
return await call_next(context)

if cached_value := await self._list_resources_cache.get(key=GLOBAL_KEY):
return cached_value
return [item.unwrap() for item in cached_value]

resources: Sequence[Resource] = await call_next(context=context)

# Turn any subclass of Resource into a Resource
cachable_resources: list[Resource] = [
Resource(
name=resource.name,
title=resource.title,
description=resource.description,
tags=resource.tags,
meta=resource.meta,
mime_type=resource.mime_type,
annotations=resource.annotations,
enabled=resource.enabled,
uri=resource.uri,
)
for resource in resources
# Wrap resources in cacheable models
cachable_resources: list[CachableResource] = [
CachableResource.wrap(resource) for resource in resources
]

await self._list_resources_cache.put(
Expand All @@ -300,7 +406,7 @@ async def on_list_resources(
ttl=self._list_resources_settings.get("ttl", FIVE_MINUTES_IN_SECONDS),
)

return cachable_resources
return [item.unwrap() for item in cachable_resources]

@override
async def on_list_prompts(
Expand All @@ -314,22 +420,13 @@ async def on_list_prompts(
return await call_next(context)

if cached_value := await self._list_prompts_cache.get(key=GLOBAL_KEY):
return cached_value
return [item.unwrap() for item in cached_value]

prompts: Sequence[Prompt] = await call_next(context=context)

# Turn any subclass of Prompt into a Prompt
cachable_prompts: list[Prompt] = [
Prompt(
name=prompt.name,
title=prompt.title,
description=prompt.description,
tags=prompt.tags,
meta=prompt.meta,
enabled=prompt.enabled,
arguments=prompt.arguments,
)
for prompt in prompts
# Wrap prompts in cacheable models
cachable_prompts: list[CachablePrompt] = [
CachablePrompt.wrap(prompt) for prompt in prompts
]

await self._list_prompts_cache.put(
Expand All @@ -338,7 +435,7 @@ async def on_list_prompts(
ttl=self._list_prompts_settings.get("ttl", FIVE_MINUTES_IN_SECONDS),
)

return cachable_prompts
return [item.unwrap() for item in cachable_prompts]

@override
async def on_call_tool(
Expand Down
76 changes: 76 additions & 0 deletions tests/server/middleware/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,79 @@ async def test_statistics(
),
)
)

async def test_mounted_server_prefixes_preserved(self):
"""Test that caching preserves prefixes from mounted servers."""
# Create child servers with tools, resources, and prompts
child = FastMCP("child")
calculator = TrackingCalculator()
calculator.add_tools(fastmcp=child)
calculator.add_resources(fastmcp=child)
calculator.add_prompts(fastmcp=child)

# Create parent with caching middleware
parent = FastMCP("parent")
parent.add_middleware(ResponseCachingMiddleware())
await parent.import_server(child, prefix="child")

async with Client[FastMCPTransport](transport=parent) as client:
# First call - populates cache
tools1 = await client.list_tools()
tool_names1 = [tool.name for tool in tools1]

# Second call - from cache (this is where the bug would occur)
tools2 = await client.list_tools()
tool_names2 = [tool.name for tool in tools2]

# All tools should have the prefix in both calls
for name in tool_names1:
assert name.startswith("child_"), (
f"Tool {name} missing prefix (first call)"
)
for name in tool_names2:
assert name.startswith("child_"), (
f"Tool {name} missing prefix (cached call)"
)

# Both calls should return the same tools
assert tool_names1 == tool_names2

# Verify tool can be called with prefixed name
result = await client.call_tool("child_add", {"a": 5, "b": 3})
assert not result.is_error

# Test resources
resources1 = await client.list_resources()
resource_names1 = [resource.name for resource in resources1]

resources2 = await client.list_resources()
resource_names2 = [resource.name for resource in resources2]

for name in resource_names1:
assert name.startswith("child_"), (
f"Resource {name} missing prefix (first call)"
)
for name in resource_names2:
assert name.startswith("child_"), (
f"Resource {name} missing prefix (cached call)"
)

assert resource_names1 == resource_names2

# Test prompts
prompts1 = await client.list_prompts()
prompt_names1 = [prompt.name for prompt in prompts1]

prompts2 = await client.list_prompts()
prompt_names2 = [prompt.name for prompt in prompts2]

for name in prompt_names1:
assert name.startswith("child_"), (
f"Prompt {name} missing prefix (first call)"
)
for name in prompt_names2:
assert name.startswith("child_"), (
f"Prompt {name} missing prefix (cached call)"
)

assert prompt_names1 == prompt_names2
Loading