Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitattribute
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Generated
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this supposed to be commited?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

uv.lock linguist-generated=true
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache/
cover/

# Translations
Expand Down Expand Up @@ -168,3 +169,6 @@ cython_debug/
.vscode/
.windsurfrules
**/CLAUDE.local.md

# claude code
.claude/
2 changes: 1 addition & 1 deletion src/mcp/server/auth/handlers/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def error_response(
if client is None and attempt_load_client:
# make last-ditch attempt to load the client
client_id = best_effort_extract_string("client_id", params)
client = client_id and await self.provider.get_client(client_id)
client = await self.provider.get_client(client_id) if client_id else None
if redirect_uri is None and client:
# make last-ditch effort to load the redirect uri
try:
Expand Down
39 changes: 33 additions & 6 deletions src/mcp/server/fastmcp/prompts/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
"""Base classes for FastMCP prompts."""

from __future__ import annotations

import inspect
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal

import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call

from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
from mcp.types import ContentBlock, TextContent

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT


class Message(BaseModel):
"""Base class for all prompt messages."""
Expand Down Expand Up @@ -62,6 +71,7 @@ class Prompt(BaseModel):
description: str | None = Field(None, description="Description of what the prompt does")
arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)

@classmethod
def from_function(
Expand All @@ -70,7 +80,8 @@ def from_function(
name: str | None = None,
title: str | None = None,
description: str | None = None,
) -> "Prompt":
context_kwarg: str | None = None,
) -> Prompt:
"""Create a Prompt from a function.

The function can return:
Expand All @@ -84,8 +95,16 @@ def from_function(
if func_name == "<lambda>":
raise ValueError("You must provide a name for lambda functions")

# Get schema from TypeAdapter - will fail if function isn't properly typed
parameters = TypeAdapter(fn).json_schema()
# Find context parameter if it exists
if context_kwarg is None:
context_kwarg = find_context_parameter(fn)

# Get schema from func_metadata, excluding context parameter
func_arg_metadata = func_metadata(
fn,
skip_names=[context_kwarg] if context_kwarg is not None else [],
)
parameters = func_arg_metadata.arg_model.model_json_schema()

# Convert parameters to PromptArguments
arguments: list[PromptArgument] = []
Expand All @@ -109,9 +128,14 @@ def from_function(
description=description or fn.__doc__ or "",
arguments=arguments,
fn=fn,
context_kwarg=context_kwarg,
)

async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]:
async def render(
self,
arguments: dict[str, Any] | None = None,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> list[Message]:
"""Render the prompt with arguments."""
# Validate required arguments
if self.arguments:
Expand All @@ -122,8 +146,11 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]
raise ValueError(f"Missing required arguments: {missing}")

try:
# Add context to arguments if needed
call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)

# Call function and check if result is a coroutine
result = self.fn(**(arguments or {}))
result = self.fn(**call_args)
if inspect.iscoroutine(result):
result = await result

Expand Down
18 changes: 15 additions & 3 deletions src/mcp/server/fastmcp/prompts/manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
"""Prompt management functionality."""

from typing import Any
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from mcp.server.fastmcp.prompts.base import Message, Prompt
from mcp.server.fastmcp.utilities.logging import get_logger

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT

logger = get_logger(__name__)


Expand Down Expand Up @@ -39,10 +46,15 @@ def add_prompt(
self._prompts[prompt.name] = prompt
return prompt

async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
async def render_prompt(
self,
name: str,
arguments: dict[str, Any] | None = None,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> list[Message]:
"""Render a prompt by name with arguments."""
prompt = self.get_prompt(name)
if not prompt:
raise ValueError(f"Unknown prompt: {name}")

return await prompt.render(arguments)
return await prompt.render(arguments, context=context)
17 changes: 14 additions & 3 deletions src/mcp/server/fastmcp/resources/resource_manager.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
"""Resource manager functionality."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any

from pydantic import AnyUrl

from mcp.server.fastmcp.resources.base import Resource
from mcp.server.fastmcp.resources.templates import ResourceTemplate
from mcp.server.fastmcp.utilities.logging import get_logger

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT

logger = get_logger(__name__)


Expand Down Expand Up @@ -67,7 +74,11 @@ def add_template(
self._templates[template.uri_template] = template
return template

async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
async def get_resource(
self,
uri: AnyUrl | str,
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Resource | None:
"""Get resource by URI, checking concrete resources first, then templates."""
uri_str = str(uri)
logger.debug("Getting resource", extra={"uri": uri_str})
Expand All @@ -80,7 +91,7 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
for template in self._templates.values():
if params := template.matches(uri_str):
try:
return await template.create_resource(uri_str, params)
return await template.create_resource(uri_str, params, context=context)
except Exception as e:
raise ValueError(f"Error creating resource from template: {e}")

Expand Down
36 changes: 31 additions & 5 deletions src/mcp/server/fastmcp/resources/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@
import inspect
import re
from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, Field, TypeAdapter, validate_call
from pydantic import BaseModel, Field, validate_call

from mcp.server.fastmcp.resources.types import FunctionResource, Resource
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
from mcp.server.fastmcp.utilities.func_metadata import func_metadata

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT


class ResourceTemplate(BaseModel):
Expand All @@ -22,6 +29,7 @@ class ResourceTemplate(BaseModel):
mime_type: str = Field(default="text/plain", description="MIME type of the resource content")
fn: Callable[..., Any] = Field(exclude=True)
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")

@classmethod
def from_function(
Expand All @@ -32,14 +40,23 @@ def from_function(
title: str | None = None,
description: str | None = None,
mime_type: str | None = None,
context_kwarg: str | None = None,
) -> ResourceTemplate:
"""Create a template from a function."""
func_name = name or fn.__name__
if func_name == "<lambda>":
raise ValueError("You must provide a name for lambda functions")

# Get schema from TypeAdapter - will fail if function isn't properly typed
parameters = TypeAdapter(fn).json_schema()
# Find context parameter if it exists
if context_kwarg is None:
context_kwarg = find_context_parameter(fn)

# Get schema from func_metadata, excluding context parameter
func_arg_metadata = func_metadata(
fn,
skip_names=[context_kwarg] if context_kwarg is not None else [],
)
parameters = func_arg_metadata.arg_model.model_json_schema()

# ensure the arguments are properly cast
fn = validate_call(fn)
Expand All @@ -52,6 +69,7 @@ def from_function(
mime_type=mime_type or "text/plain",
fn=fn,
parameters=parameters,
context_kwarg=context_kwarg,
)

def matches(self, uri: str) -> dict[str, Any] | None:
Expand All @@ -63,9 +81,17 @@ def matches(self, uri: str) -> dict[str, Any] | None:
return match.groupdict()
return None

async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource:
async def create_resource(
self,
uri: str,
params: dict[str, Any],
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
) -> Resource:
"""Create a resource from the template with the given parameters."""
try:
# Add context to params if needed
params = inject_context(self.fn, params, context, self.context_kwarg)

# Call function and check if result is a coroutine
result = self.fn(**params)
if inspect.iscoroutine(result):
Expand Down
18 changes: 13 additions & 5 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from mcp.server.fastmcp.prompts import Prompt, PromptManager
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
from mcp.server.fastmcp.tools import Tool, ToolManager
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.server.lowlevel.server import LifespanResultT
Expand Down Expand Up @@ -326,7 +327,8 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]:
async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]:
"""Read a resource by URI."""

resource = await self._resource_manager.get_resource(uri)
context = self.get_context()
resource = await self._resource_manager.get_resource(uri, context=context)
if not resource:
raise ResourceError(f"Unknown resource: {uri}")

Expand Down Expand Up @@ -510,13 +512,19 @@ async def get_weather(city: str) -> str:

def decorator(fn: AnyFunction) -> AnyFunction:
# Check if this should be a template
sig = inspect.signature(fn)
has_uri_params = "{" in uri and "}" in uri
has_func_params = bool(inspect.signature(fn).parameters)
has_func_params = bool(sig.parameters)

if has_uri_params or has_func_params:
# Validate that URI params match function params
# Check for Context parameter to exclude from validation
context_param = find_context_parameter(fn)

# Validate that URI params match function params (excluding context)
uri_params = set(re.findall(r"{(\w+)}", uri))
func_params = set(inspect.signature(fn).parameters.keys())
# We need to remove the context_param from the resource function if
# there is any.
func_params = {p for p in sig.parameters.keys() if p != context_param}

if uri_params != func_params:
raise ValueError(
Expand Down Expand Up @@ -982,7 +990,7 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -
if not prompt:
raise ValueError(f"Unknown prompt: {name}")

messages = await prompt.render(arguments)
messages = await prompt.render(arguments, context=self.get_context())

return GetPromptResult(
description=prompt.description,
Expand Down
13 changes: 3 additions & 10 deletions src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import inspect
from collections.abc import Callable
from functools import cached_property
from typing import TYPE_CHECKING, Any, get_origin
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, Field

from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.types import ToolAnnotations

Expand Down Expand Up @@ -49,8 +50,6 @@ def from_function(
structured_output: bool | None = None,
) -> Tool:
"""Create a Tool from a function."""
from mcp.server.fastmcp.server import Context

func_name = name or fn.__name__

if func_name == "<lambda>":
Expand All @@ -60,13 +59,7 @@ def from_function(
is_async = _is_async_callable(fn)

if context_kwarg is None:
sig = inspect.signature(fn)
for param_name, param in sig.parameters.items():
if get_origin(param.annotation) is not None:
continue
if issubclass(param.annotation, Context):
context_kwarg = param_name
break
context_kwarg = find_context_parameter(fn)

func_arg_metadata = func_metadata(
fn,
Expand Down
Loading
Loading