diff --git a/examples/snippets/servers/context_resource.py b/examples/snippets/servers/context_resource.py new file mode 100644 index 000000000..d2d7c5409 --- /dev/null +++ b/examples/snippets/servers/context_resource.py @@ -0,0 +1,11 @@ +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession + +mcp = FastMCP(name="Context Resource Example") + + +@mcp.resource("resource://only_context") +def resource_only_context(ctx: Context[ServerSession, None]) -> str: + """Resource that only receives context.""" + assert ctx is not None + return "Resource with only context injected" diff --git a/src/mcp/server/fastmcp/resources/base.py b/src/mcp/server/fastmcp/resources/base.py index 0bef1a266..4a6ec0420 100644 --- a/src/mcp/server/fastmcp/resources/base.py +++ b/src/mcp/server/fastmcp/resources/base.py @@ -1,7 +1,7 @@ """Base classes and interfaces for FastMCP resources.""" import abc -from typing import Annotated +from typing import Annotated, Any from pydantic import ( AnyUrl, @@ -43,6 +43,6 @@ def set_default_name(cls, name: str | None, info: ValidationInfo) -> str: raise ValueError("Either name or uri must be provided") @abc.abstractmethod - async def read(self) -> str | bytes: + async def read(self, context: Any | None = None) -> str | bytes: """Read the resource content.""" pass diff --git a/src/mcp/server/fastmcp/resources/types.py b/src/mcp/server/fastmcp/resources/types.py index c578e23de..fbfcc21ed 100644 --- a/src/mcp/server/fastmcp/resources/types.py +++ b/src/mcp/server/fastmcp/resources/types.py @@ -14,6 +14,7 @@ from pydantic import AnyUrl, Field, ValidationInfo, validate_call from mcp.server.fastmcp.resources.base import Resource +from mcp.server.fastmcp.utilities.context_injection import find_context_parameter from mcp.types import Icon @@ -22,7 +23,7 @@ class TextResource(Resource): text: str = Field(description="Text content of the resource") - async def read(self) -> str: + async def read(self, context: Any | None = None) -> str: """Read the text content.""" return self.text @@ -32,7 +33,7 @@ class BinaryResource(Resource): data: bytes = Field(description="Binary content of the resource") - async def read(self) -> bytes: + async def read(self, context: Any | None = None) -> bytes: """Read the binary content.""" return self.data @@ -51,24 +52,30 @@ class FunctionResource(Resource): """ fn: Callable[[], Any] = Field(exclude=True) + context_kwarg: str | None = Field(None, exclude=True) + + async def read(self, context: Any | None = None) -> str | bytes: + """Read the resource content by calling the function.""" + args = {} + if self.context_kwarg: + args[self.context_kwarg] = context - async def read(self) -> str | bytes: - """Read the resource by calling the wrapped function.""" try: - # Call the function first to see if it returns a coroutine - result = self.fn() - # If it's a coroutine, await it - if inspect.iscoroutine(result): - result = await result - - if isinstance(result, Resource): - return await result.read() - elif isinstance(result, bytes): - return result - elif isinstance(result, str): - return result + if inspect.iscoroutinefunction(self.fn): + result = await self.fn(**args) else: - return pydantic_core.to_json(result, fallback=str, indent=2).decode() + result = self.fn(**args) + + if isinstance(result, str | bytes): + return result + if isinstance(result, pydantic.BaseModel): + return result.model_dump_json(indent=2) + + # For other types, convert to a JSON string + try: + return json.dumps(pydantic_core.to_jsonable_python(result)) + except pydantic_core.PydanticSerializationError: + return json.dumps(str(result)) except Exception as e: raise ValueError(f"Error reading resource {self.uri}: {e}") @@ -88,6 +95,8 @@ def from_function( if func_name == "": raise ValueError("You must provide a name for lambda functions") + context_kwarg = find_context_parameter(fn) + # ensure the arguments are properly cast fn = validate_call(fn) @@ -99,6 +108,7 @@ def from_function( mime_type=mime_type or "text/plain", fn=fn, icons=icons, + context_kwarg=context_kwarg, ) @@ -135,7 +145,7 @@ def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> boo mime_type = info.data.get("mime_type", "text/plain") return not mime_type.startswith("text/") - async def read(self) -> str | bytes: + async def read(self, context: Any | None = None) -> str | bytes: """Read the file content.""" try: if self.is_binary: @@ -151,7 +161,7 @@ class HttpResource(Resource): url: str = Field(description="URL to fetch content from") mime_type: str = Field(default="application/json", description="MIME type of the resource content") - async def read(self) -> str | bytes: + async def read(self, context: Any | None = None) -> str | bytes: """Read the HTTP content.""" async with httpx.AsyncClient() as client: response = await client.get(self.url) @@ -189,7 +199,7 @@ def list_files(self) -> list[Path]: except Exception as e: raise ValueError(f"Error listing directory {self.path}: {e}") - async def read(self) -> str: # Always returns JSON string + async def read(self, context: Any | None = None) -> str: # Always returns JSON string """Read the directory listing.""" try: files = await anyio.to_thread.run_sync(self.list_files) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 485ef1519..d587c2ab7 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -349,7 +349,7 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent raise ResourceError(f"Unknown resource: {uri}") try: - content = await resource.read() + content = await resource.read(context=context) return [ReadResourceContents(content=content, mime_type=resource.mime_type)] except Exception as e: logger.exception(f"Error reading resource {uri}") @@ -543,27 +543,24 @@ async def get_weather(city: str) -> str: ) def decorator(fn: AnyFunction) -> AnyFunction: - # Check if this should be a template sig = inspect.signature(fn) - has_uri_params = "{" in uri and "}" in uri - has_func_params = bool(sig.parameters) + context_param = find_context_parameter(fn) + + # Determine effective parameters, excluding context + effective_func_params = {p for p in sig.parameters.keys() if p != context_param} - if has_uri_params or has_func_params: - # Check for Context parameter to exclude from validation - context_param = find_context_parameter(fn) + has_uri_params = "{" in uri and "}" in uri + has_effective_func_params = bool(effective_func_params) - # Validate that URI params match function params (excluding context) + if has_uri_params or has_effective_func_params: + # Register as template uri_params = set(re.findall(r"{(\w+)}", uri)) - # We need to remove the context_param from the resource function if - # there is any. - func_params = {p for p in sig.parameters.keys() if p != context_param} - if uri_params != func_params: + if uri_params != effective_func_params: raise ValueError( - f"Mismatch between URI parameters {uri_params} and function parameters {func_params}" + f"Mismatch between URI parameters {uri_params} and function parameters {effective_func_params}" ) - # Register as template self._resource_manager.add_template( fn=fn, uri_template=uri, diff --git a/src/mcp/server/fastmcp/utilities/context_injection.py b/src/mcp/server/fastmcp/utilities/context_injection.py index 66d0cbaa0..88370086d 100644 --- a/src/mcp/server/fastmcp/utilities/context_injection.py +++ b/src/mcp/server/fastmcp/utilities/context_injection.py @@ -31,17 +31,16 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None: # Check each parameter's type hint for param_name, annotation in hints.items(): - # Handle direct Context type + # Handle direct Context type and generic aliases of Context + origin = typing.get_origin(annotation) + + # Check if the annotation itself is Context or a subclass if inspect.isclass(annotation) and issubclass(annotation, Context): return param_name - # Handle generic types like Optional[Context] - origin = typing.get_origin(annotation) - if origin is not None: - args = typing.get_args(annotation) - for arg in args: - if inspect.isclass(arg) and issubclass(arg, Context): - return param_name + # Check if it's a generic alias of Context (e.g., Context[...]) + if origin is not None and inspect.isclass(origin) and issubclass(origin, Context): + return param_name return None diff --git a/tests/server/fastmcp/resources/test_function_resources.py b/tests/server/fastmcp/resources/test_function_resources.py index f30c6e713..05e5838ec 100644 --- a/tests/server/fastmcp/resources/test_function_resources.py +++ b/tests/server/fastmcp/resources/test_function_resources.py @@ -18,6 +18,7 @@ def my_func() -> str: name="test", description="test function", fn=my_func, + context_kwarg=None, ) assert str(resource.uri) == "fn://test" assert resource.name == "test" @@ -36,6 +37,7 @@ def get_data() -> str: uri=AnyUrl("function://test"), name="test", fn=get_data, + context_kwarg=None, ) content = await resource.read() assert content == "Hello, world!" @@ -52,6 +54,7 @@ def get_data() -> bytes: uri=AnyUrl("function://test"), name="test", fn=get_data, + context_kwarg=None, ) content = await resource.read() assert content == b"Hello, world!" @@ -67,6 +70,7 @@ def get_data() -> dict[str, str]: uri=AnyUrl("function://test"), name="test", fn=get_data, + context_kwarg=None, ) content = await resource.read() assert isinstance(content, str) @@ -83,6 +87,7 @@ def failing_func() -> str: uri=AnyUrl("function://test"), name="test", fn=failing_func, + context_kwarg=None, ) with pytest.raises(ValueError, match="Error reading resource function://test"): await resource.read() @@ -98,6 +103,7 @@ class MyModel(BaseModel): uri=AnyUrl("function://test"), name="test", fn=lambda: MyModel(name="test"), + context_kwarg=None, ) content = await resource.read() assert content == '{\n "name": "test"\n}' @@ -117,6 +123,7 @@ def get_data() -> CustomData: uri=AnyUrl("function://test"), name="test", fn=get_data, + context_kwarg=None, ) content = await resource.read() assert isinstance(content, str) @@ -132,6 +139,7 @@ async def get_data() -> str: uri=AnyUrl("function://test"), name="test", fn=get_data, + context_kwarg=None, ) content = await resource.read() assert content == "Hello, world!" diff --git a/tests/server/fastmcp/resources/test_resources.py b/tests/server/fastmcp/resources/test_resources.py index 08b3e65e1..24df3069a 100644 --- a/tests/server/fastmcp/resources/test_resources.py +++ b/tests/server/fastmcp/resources/test_resources.py @@ -18,6 +18,7 @@ def dummy_func() -> str: uri=AnyUrl("http://example.com/data"), name="test", fn=dummy_func, + context_kwarg=None, ) assert str(resource.uri) == "http://example.com/data" @@ -27,6 +28,7 @@ def dummy_func() -> str: uri=AnyUrl("invalid"), name="test", fn=dummy_func, + context_kwarg=None, ) # Missing host @@ -35,6 +37,7 @@ def dummy_func() -> str: uri=AnyUrl("http://"), name="test", fn=dummy_func, + context_kwarg=None, ) def test_resource_name_from_uri(self): @@ -46,6 +49,7 @@ def dummy_func() -> str: resource = FunctionResource( uri=AnyUrl("resource://my-resource"), fn=dummy_func, + context_kwarg=None, ) assert resource.name == "resource://my-resource" @@ -59,6 +63,7 @@ def dummy_func() -> str: with pytest.raises(ValueError, match="Either name or uri must be provided"): FunctionResource( fn=dummy_func, + context_kwarg=None, ) # Explicit name takes precedence over URI @@ -66,6 +71,7 @@ def dummy_func() -> str: uri=AnyUrl("resource://uri-name"), name="explicit-name", fn=dummy_func, + context_kwarg=None, ) assert resource.name == "explicit-name" @@ -79,6 +85,7 @@ def dummy_func() -> str: resource = FunctionResource( uri=AnyUrl("resource://test"), fn=dummy_func, + context_kwarg=None, ) assert resource.mime_type == "text/plain" @@ -87,6 +94,7 @@ def dummy_func() -> str: uri=AnyUrl("resource://test"), fn=dummy_func, mime_type="application/json", + context_kwarg=None, ) assert resource.mime_type == "application/json" diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index dc88cc025..8acde6b03 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -26,6 +26,7 @@ basic_resource, basic_tool, completion, + context_resource, elicitation, fastmcp_quickstart, notifications, @@ -124,6 +125,8 @@ def run_server_with_transport(module_name: str, port: int, transport: str) -> No mcp = fastmcp_quickstart.mcp elif module_name == "structured_output": mcp = structured_output.mcp + elif module_name == "context_resource": + mcp = context_resource.mcp else: raise ImportError(f"Unknown module: {module_name}") @@ -697,3 +700,42 @@ async def test_structured_output(server_transport: str, server_url: str) -> None assert "sunny" in result_text # condition assert "45" in result_text # humidity assert "5.2" in result_text # wind_speed + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "server_transport", + [ + ("context_resource", "sse"), + ("context_resource", "streamable-http"), + ], + indirect=True, +) +async def test_context_only_resource(server_transport: str, server_url: str) -> None: + """Test that a resource with only a context argument is registered as a regular resource.""" + transport = server_transport + client_cm = create_client_for_transport(transport, server_url) + + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) + async with ClientSession(read_stream, write_stream) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "Context Resource Example" + + # Check that it is not in templates + templates = await session.list_resource_templates() + assert len(templates.resourceTemplates) == 0 + + # Check that it is in resources + resources = await session.list_resources() + assert len(resources.resources) == 1 + resource = resources.resources[0] + assert resource.uri == AnyUrl("resource://only_context") + + # Check that we can read it + read_result = await session.read_resource(AnyUrl("resource://only_context")) + assert len(read_result.contents) == 1 + assert isinstance(read_result.contents[0], TextResourceContents) + assert read_result.contents[0].text == "Resource with only context injected" diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 8caa3b1f6..f8fce42cc 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -686,7 +686,7 @@ async def test_text_resource(self): def get_text(): return "Hello, world!" - resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text) + resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text, context_kwarg=None) mcp.add_resource(resource) async with client_session(mcp._mcp_server) as client: @@ -706,6 +706,7 @@ def get_binary(): name="binary", fn=get_binary, mime_type="application/octet-stream", + context_kwarg=None, ) mcp.add_resource(resource) @@ -1120,6 +1121,33 @@ def resource_custom_ctx(id: str, my_ctx: Context[ServerSession, None]) -> str: assert isinstance(content, TextResourceContents) assert "Resource 123 with context" in content.text + @pytest.mark.anyio + async def test_resource_only_context(self): + """Test that resources without template args can receive context.""" + mcp = FastMCP() + + @mcp.resource("resource://only_context", name="resource_with_context_no_args") + def resource_only_context(ctx: Context[ServerSession, None]) -> str: + """Resource that only receives context.""" + assert ctx is not None + return "Resource with only context injected" + + # Test via client + async with client_session(mcp._mcp_server) as client: + # Verify resource is registered via client + resources = await client.list_resources() + assert len(resources.resources) == 1 + resource = resources.resources[0] + assert resource.uri == AnyUrl("resource://only_context") + assert resource.name == "resource_with_context_no_args" + + # Test reading the resource + result = await client.read_resource(AnyUrl("resource://only_context")) + assert len(result.contents) == 1 + content = result.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Resource with only context injected" + @pytest.mark.anyio async def test_prompt_with_context(self): """Test that prompts can receive context parameter.""" diff --git a/tests/server/fastmcp/test_title.py b/tests/server/fastmcp/test_title.py index a94f6671d..5af129986 100644 --- a/tests/server/fastmcp/test_title.py +++ b/tests/server/fastmcp/test_title.py @@ -115,6 +115,7 @@ def get_basic_data() -> str: name="basic_resource", description="Basic resource", fn=get_basic_data, + context_kwarg=None, ) mcp.add_resource(basic_resource) @@ -128,6 +129,7 @@ def get_titled_data() -> str: title="User-Friendly Resource", description="Resource with title", fn=get_titled_data, + context_kwarg=None, ) mcp.add_resource(titled_resource)