Skip to content

Commit 092d769

Browse files
committed
pass context to all permit methods to allow checking other request params such as http headers
1 parent a587453 commit 092d769

File tree

6 files changed

+118
-54
lines changed

6 files changed

+118
-54
lines changed

src/mcp/server/fastmcp/authorizer.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,28 @@
44
from typing import TYPE_CHECKING, Any
55

66
from pydantic import AnyUrl
7+
from starlette.requests import Request
78

8-
from mcp.shared.context import LifespanContextT, RequestT
9+
from mcp.server.session import ServerSession
910

1011
if TYPE_CHECKING:
1112
from mcp.server.fastmcp.server import Context
12-
from mcp.server.session import ServerSessionT
1313

1414

1515
class Authorizer:
1616
__metaclass__ = abc.ABCMeta
1717

1818
@abc.abstractmethod
19-
def permit_get_tool(self, name: str) -> bool:
19+
def permit_get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
2020
"""Check if the specified tool can be retrieved from the associated mcp server"""
2121
return False
2222

2323
@abc.abstractmethod
24-
def permit_list_tool(self, name: str) -> bool:
24+
def permit_list_tool(
25+
self,
26+
name: str,
27+
context: Context[ServerSession, object, Request] | None = None,
28+
) -> bool:
2529
"""Check if the specified tool can be listed from the associated mcp server"""
2630
return False
2731

@@ -30,79 +34,105 @@ def permit_call_tool(
3034
self,
3135
name: str,
3236
arguments: dict[str, Any],
33-
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
37+
context: Context[ServerSession, object, Request] | None = None,
3438
) -> bool:
3539
"""Check if the specified tool can be called from the associated mcp server"""
3640
return False
3741

3842
@abc.abstractmethod
39-
def permit_get_resource(self, resource: AnyUrl | str) -> bool:
43+
def permit_get_resource(
44+
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
45+
) -> bool:
4046
"""Check if the specified resource can be retrieved from the associated mcp server"""
4147
return False
4248

4349
@abc.abstractmethod
44-
def permit_create_resource(self, uri: str, params: dict[str, Any]) -> bool:
50+
def permit_create_resource(
51+
self, uri: str, params: dict[str, Any], context: Context[ServerSession, object, Request] | None = None
52+
) -> bool:
4553
"""Check if the specified resource can be created on the associated mcp server"""
4654
return False
4755

4856
@abc.abstractmethod
49-
def permit_list_resource(self, resource: AnyUrl | str) -> bool:
57+
def permit_list_resource(
58+
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
59+
) -> bool:
5060
"""Check if the specified resource can be listed from the associated mcp server"""
5161
return False
5262

5363
@abc.abstractmethod
54-
def permit_list_template(self, resource: AnyUrl | str) -> bool:
64+
def permit_list_template(
65+
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
66+
) -> bool:
5567
"""Check if the specified template can be listed from the associated mcp server"""
5668
return False
5769

5870
@abc.abstractmethod
59-
def permit_get_prompt(self, name: str) -> bool:
71+
def permit_get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
6072
"""Check if the specified prompt can be retrieved from the associated mcp server"""
6173
return False
6274

6375
@abc.abstractmethod
64-
def permit_list_prompt(self, name: str) -> bool:
76+
def permit_list_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
6577
"""Check if the specified prompt can be listed from the associated mcp server"""
6678
return False
6779

6880
@abc.abstractmethod
69-
def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool:
81+
def permit_render_prompt(
82+
self,
83+
name: str,
84+
arguments: dict[str, Any] | None = None,
85+
context: Context[ServerSession, object, Request] | None = None,
86+
) -> bool:
7087
"""Check if the specified prompt can be rendered from the associated mcp server"""
7188
return False
7289

7390

7491
class AllAllAuthorizer(Authorizer):
75-
def permit_get_tool(self, name: str) -> bool:
92+
def permit_get_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
7693
return True
7794

78-
def permit_list_tool(self, name: str) -> bool:
95+
def permit_list_tool(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
7996
return True
8097

8198
def permit_call_tool(
8299
self,
83100
name: str,
84101
arguments: dict[str, Any],
85-
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
102+
context: Context[ServerSession, object, Request] | None = None,
86103
) -> bool:
87104
return True
88105

89-
def permit_get_resource(self, resource: AnyUrl | str) -> bool:
106+
def permit_get_resource(
107+
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
108+
) -> bool:
90109
return True
91110

92-
def permit_create_resource(self, uri: str, params: dict[str, Any]) -> bool:
111+
def permit_create_resource(
112+
self, uri: str, params: dict[str, Any], context: Context[ServerSession, object, Request] | None = None
113+
) -> bool:
93114
return True
94115

95-
def permit_list_resource(self, resource: AnyUrl | str) -> bool:
116+
def permit_list_resource(
117+
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
118+
) -> bool:
96119
return True
97120

98-
def permit_list_template(self, resource: AnyUrl | str) -> bool:
121+
def permit_list_template(
122+
self, resource: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
123+
) -> bool:
99124
return True
100125

101-
def permit_get_prompt(self, name: str) -> bool:
126+
def permit_get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
102127
return True
103128

104-
def permit_list_prompt(self, name: str) -> bool:
129+
def permit_list_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> bool:
105130
return True
106131

107-
def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool:
132+
def permit_render_prompt(
133+
self,
134+
name: str,
135+
arguments: dict[str, Any] | None = None,
136+
context: Context[ServerSession, object, Request] | None = None,
137+
) -> bool:
108138
return True

src/mcp/server/fastmcp/prompts/manager.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
"""Prompt management functionality."""
22

3-
from typing import Any
3+
from __future__ import annotations as _annotations
4+
5+
from typing import TYPE_CHECKING, Any
6+
7+
from starlette.requests import Request
48

59
from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer
610
from mcp.server.fastmcp.prompts.base import Message, Prompt
711
from mcp.server.fastmcp.utilities.logging import get_logger
12+
from mcp.server.session import ServerSession
13+
14+
if TYPE_CHECKING:
15+
from mcp.server.fastmcp.server import Context
816

917
logger = get_logger(__name__)
1018

@@ -21,16 +29,16 @@ def __init__(
2129
self._authorizer = authorizer
2230
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts
2331

24-
def get_prompt(self, name: str) -> Prompt | None:
32+
def get_prompt(self, name: str, context: Context[ServerSession, object, Request] | None = None) -> Prompt | None:
2533
"""Get prompt by name."""
26-
if self._authorizer.permit_get_prompt(name):
34+
if self._authorizer.permit_get_prompt(name, context):
2735
return self._prompts.get(name)
2836
else:
2937
return None
3038

31-
def list_prompts(self) -> list[Prompt]:
39+
def list_prompts(self, context: Context[ServerSession, object, Request] | None = None) -> list[Prompt]:
3240
"""List all registered prompts."""
33-
return [prompt for name, prompt in self._prompts.items() if self._authorizer.permit_list_prompt(name)]
41+
return [prompt for name, prompt in self._prompts.items() if self._authorizer.permit_list_prompt(name, context)]
3442

3543
def add_prompt(
3644
self,
@@ -48,12 +56,17 @@ def add_prompt(
4856
self._prompts[prompt.name] = prompt
4957
return prompt
5058

51-
async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
59+
async def render_prompt(
60+
self,
61+
name: str,
62+
arguments: dict[str, Any] | None = None,
63+
context: Context[ServerSession, object, Request] | None = None,
64+
) -> list[Message]:
5265
"""Render a prompt by name with arguments."""
5366
prompt = self.get_prompt(name)
5467
if not prompt:
5568
raise ValueError(f"Unknown prompt: {name}")
56-
if self._authorizer.permit_render_prompt(name, arguments):
69+
if self._authorizer.permit_render_prompt(name, arguments, context):
5770
return await prompt.render(arguments)
5871
else:
5972
raise ValueError(f"Unknown prompt: {name}")

src/mcp/server/fastmcp/resources/resource_manager.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
"""Resource manager functionality."""
22

3+
from __future__ import annotations as _annotations
4+
35
from collections.abc import Callable
4-
from typing import Any
6+
from typing import TYPE_CHECKING, Any
57

68
from pydantic import AnyUrl
9+
from starlette.requests import Request
710

811
from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer
912
from mcp.server.fastmcp.resources.base import Resource
1013
from mcp.server.fastmcp.resources.templates import ResourceTemplate
1114
from mcp.server.fastmcp.utilities.logging import get_logger
15+
from mcp.server.session import ServerSession
16+
17+
if TYPE_CHECKING:
18+
from mcp.server.fastmcp.server import Context
1219

1320
logger = get_logger(__name__)
1421

@@ -73,14 +80,16 @@ def add_template(
7380
self._templates[template.uri_template] = template
7481
return template
7582

76-
async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
83+
async def get_resource(
84+
self, uri: AnyUrl | str, context: Context[ServerSession, object, Request] | None = None
85+
) -> Resource | None:
7786
"""Get resource by URI, checking concrete resources first, then templates."""
7887
uri_str = str(uri)
7988
logger.debug("Getting resource", extra={"uri": uri_str})
8089

8190
# First check concrete resources
8291
if resource := self._resources.get(uri_str):
83-
if self._authorizer.permit_get_resource(uri_str):
92+
if self._authorizer.permit_get_resource(uri_str, context):
8493
return resource
8594
else:
8695
raise ValueError(f"Unknown resource: {uri}")
@@ -98,12 +107,16 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
98107

99108
raise ValueError(f"Unknown resource: {uri}")
100109

101-
def list_resources(self) -> list[Resource]:
110+
def list_resources(self, context: Context[ServerSession, object, Request] | None = None) -> list[Resource]:
102111
"""List all registered resources."""
103112
logger.debug("Listing resources", extra={"count": len(self._resources)})
104-
return [resource for uri, resource in self._resources.items() if self._authorizer.permit_list_resource(uri)]
113+
return [
114+
resource for uri, resource in self._resources.items() if self._authorizer.permit_list_resource(uri, context)
115+
]
105116

106-
def list_templates(self) -> list[ResourceTemplate]:
117+
def list_templates(self, context: Context[ServerSession, object, Request] | None = None) -> list[ResourceTemplate]:
107118
"""List all registered templates."""
108119
logger.debug("Listing templates", extra={"count": len(self._templates)})
109-
return [template for uri, template in self._templates.items() if self._authorizer.permit_list_template(uri)]
120+
return [
121+
template for uri, template in self._templates.items() if self._authorizer.permit_list_template(uri, context)
122+
]

src/mcp/server/fastmcp/server.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ def _setup_handlers(self) -> None:
257257

258258
async def list_tools(self) -> list[MCPTool]:
259259
"""List all available tools."""
260-
tools = self._tool_manager.list_tools()
260+
context = self.get_context()
261+
tools = self._tool_manager.list_tools(context)
261262
return [
262263
MCPTool(
263264
name=info.name,
@@ -289,8 +290,8 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Cont
289290

290291
async def list_resources(self) -> list[MCPResource]:
291292
"""List all available resources."""
292-
293-
resources = self._resource_manager.list_resources()
293+
context = self.get_context()
294+
resources = self._resource_manager.list_resources(context)
294295
return [
295296
MCPResource(
296297
uri=resource.uri,
@@ -303,7 +304,8 @@ async def list_resources(self) -> list[MCPResource]:
303304
]
304305

305306
async def list_resource_templates(self) -> list[MCPResourceTemplate]:
306-
templates = self._resource_manager.list_templates()
307+
context = self.get_context()
308+
templates = self._resource_manager.list_templates(context)
307309
return [
308310
MCPResourceTemplate(
309311
uriTemplate=template.uri_template,
@@ -316,8 +318,8 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]:
316318

317319
async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]:
318320
"""Read a resource by URI."""
319-
320-
resource = await self._resource_manager.get_resource(uri)
321+
context = self.get_context()
322+
resource = await self._resource_manager.get_resource(uri, context)
321323
if not resource:
322324
raise ResourceError(f"Unknown resource: {uri}")
323325

@@ -924,9 +926,9 @@ async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) ->
924926
lifespan=lambda app: self.session_manager.run(),
925927
)
926928

927-
async def list_prompts(self) -> list[MCPPrompt]:
929+
async def list_prompts(self, context: Context[ServerSession, object, Request] | None = None) -> list[MCPPrompt]:
928930
"""List all available prompts."""
929-
prompts = self._prompt_manager.list_prompts()
931+
prompts = self._prompt_manager.list_prompts(context)
930932
return [
931933
MCPPrompt(
932934
name=prompt.name,
@@ -944,10 +946,15 @@ async def list_prompts(self) -> list[MCPPrompt]:
944946
for prompt in prompts
945947
]
946948

947-
async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult:
949+
async def get_prompt(
950+
self,
951+
name: str,
952+
arguments: dict[str, Any] | None = None,
953+
context: Context[ServerSession, object, Request] | None = None,
954+
) -> GetPromptResult:
948955
"""Get a prompt by name with arguments."""
949956
try:
950-
messages = await self._prompt_manager.render_prompt(name, arguments)
957+
messages = await self._prompt_manager.render_prompt(name, arguments, context)
951958

952959
return GetPromptResult(messages=pydantic_core.to_jsonable_python(messages))
953960
except Exception as e:

0 commit comments

Comments
 (0)