Skip to content

Commit 526ddff

Browse files
committed
Merge branch 'ihrpr/client-resumability' into ihrpr/server-closing-streams
2 parents db24790 + d1bd44b commit 526ddff

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+4578
-137
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,5 @@ cython_debug/
166166

167167
# vscode
168168
.vscode/
169+
.windsurfrules
169170
**/CLAUDE.local.md

CLAUDE.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ This document contains critical information about working with this codebase. Fo
1919
- Line length: 88 chars maximum
2020

2121
3. Testing Requirements
22-
- Framework: `uv run pytest`
22+
- Framework: `uv run --frozen pytest`
2323
- Async testing: use anyio, not asyncio
2424
- Coverage: test edge cases and errors
2525
- New features require tests
@@ -54,9 +54,9 @@ This document contains critical information about working with this codebase. Fo
5454
## Code Formatting
5555

5656
1. Ruff
57-
- Format: `uv run ruff format .`
58-
- Check: `uv run ruff check .`
59-
- Fix: `uv run ruff check . --fix`
57+
- Format: `uv run --frozen ruff format .`
58+
- Check: `uv run --frozen ruff check .`
59+
- Fix: `uv run --frozen ruff check . --fix`
6060
- Critical issues:
6161
- Line length (88 chars)
6262
- Import sorting (I001)
@@ -67,7 +67,7 @@ This document contains critical information about working with this codebase. Fo
6767
- Imports: split into multiple lines
6868

6969
2. Type Checking
70-
- Tool: `uv run pyright`
70+
- Tool: `uv run --frozen pyright`
7171
- Requirements:
7272
- Explicit None checks for Optional
7373
- Type narrowing for strings
@@ -104,6 +104,10 @@ This document contains critical information about working with this codebase. Fo
104104
- Add None checks
105105
- Narrow string types
106106
- Match existing patterns
107+
- Pytest:
108+
- If the tests aren't finding the anyio pytest mark, try adding PYTEST_DISABLE_PLUGIN_AUTOLOAD=""
109+
to the start of the pytest run command eg:
110+
`PYTEST_DISABLE_PLUGIN_AUTOLOAD="" uv run --frozen pytest`
107111

108112
3. Best Practices
109113
- Check git status before commits

README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,33 @@ async def long_task(files: list[str], ctx: Context) -> str:
309309
return "Processing complete"
310310
```
311311

312+
### Authentication
313+
314+
Authentication can be used by servers that want to expose tools accessing protected resources.
315+
316+
`mcp.server.auth` implements an OAuth 2.0 server interface, which servers can use by
317+
providing an implementation of the `OAuthServerProvider` protocol.
318+
319+
```
320+
mcp = FastMCP("My App",
321+
auth_provider=MyOAuthServerProvider(),
322+
auth=AuthSettings(
323+
issuer_url="https://myapp.com",
324+
revocation_options=RevocationOptions(
325+
enabled=True,
326+
),
327+
client_registration_options=ClientRegistrationOptions(
328+
enabled=True,
329+
valid_scopes=["myscope", "myotherscope"],
330+
default_scopes=["myscope"],
331+
),
332+
required_scopes=["myscope"],
333+
),
334+
)
335+
```
336+
337+
See [OAuthServerProvider](mcp/server/auth/provider.py) for more details.
338+
312339
## Running Your Server
313340

314341
### Development Mode

examples/clients/simple-chatbot/mcp_simple_chatbot/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,7 @@ async def process_llm_response(self, llm_response: str) -> str:
323323
total = result["total"]
324324
percentage = (progress / total) * 100
325325
logging.info(
326-
f"Progress: {progress}/{total} "
327-
f"({percentage:.1f}%)"
326+
f"Progress: {progress}/{total} ({percentage:.1f}%)"
328327
)
329328

330329
return f"Tool execution result: {result}"

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
"httpx-sse>=0.4",
2828
"pydantic>=2.7.2,<3.0.0",
2929
"starlette>=0.27",
30+
"python-multipart>=0.0.9",
3031
"sse-starlette>=1.6.1",
3132
"pydantic-settings>=2.5.2",
3233
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
@@ -53,6 +54,7 @@ dev = [
5354
"pytest-flakefinder>=1.1.0",
5455
"pytest-xdist>=3.6.1",
5556
"pytest-examples>=0.0.14",
57+
"pytest-pretty>=1.2.0",
5658
]
5759
docs = [
5860
"mkdocs>=1.6.1",

src/mcp/client/session.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
import mcp.types as types
99
from mcp.shared.context import RequestContext
1010
from mcp.shared.message import (
11-
ClientMessageMetadata,
12-
ResumptionToken,
13-
ResumptionTokenUpdateCallback,
1411
SessionMessage,
1512
)
1613
from mcp.shared.session import BaseSession, RequestResponder
@@ -263,16 +260,9 @@ async def call_tool(
263260
self,
264261
name: str,
265262
arguments: dict[str, Any] | None = None,
266-
on_resumption_token_update: ResumptionTokenUpdateCallback | None = None,
267-
resumption_token: ResumptionToken | None = None,
263+
read_timeout_seconds: timedelta | None = None,
268264
) -> types.CallToolResult:
269265
"""Send a tools/call request."""
270-
metadata = None
271-
if on_resumption_token_update or resumption_token:
272-
metadata = ClientMessageMetadata(
273-
on_resumption_token_update=on_resumption_token_update,
274-
resumption_token=resumption_token,
275-
)
276266

277267
return await self.send_request(
278268
types.ClientRequest(
@@ -282,7 +272,7 @@ async def call_tool(
282272
)
283273
),
284274
types.CallToolResult,
285-
metadata=metadata,
275+
request_read_timeout_seconds=read_timeout_seconds,
286276
)
287277

288278
async def list_prompts(self) -> types.ListPromptsResult:

src/mcp/client/streamable_http.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99
import logging
10-
from collections.abc import Awaitable, Callable
10+
from collections.abc import AsyncGenerator, Awaitable, Callable
1111
from contextlib import asynccontextmanager
1212
from dataclasses import dataclass
1313
from datetime import timedelta
@@ -16,7 +16,7 @@
1616
import anyio
1717
import httpx
1818
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
19-
from httpx_sse import EventSource, aconnect_sse
19+
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2020

2121
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2222
from mcp.types import (
@@ -26,15 +26,16 @@
2626
JSONRPCNotification,
2727
JSONRPCRequest,
2828
JSONRPCResponse,
29+
RequestId,
2930
)
3031

3132
logger = logging.getLogger(__name__)
3233

3334

34-
MessageOrError = SessionMessage | Exception
35-
StreamWriter = MemoryObjectSendStream[MessageOrError]
35+
SessionMessageOrError = SessionMessage | Exception
36+
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
3637
StreamReader = MemoryObjectReceiveStream[SessionMessage]
37-
38+
GetSessionIdCallback = Callable[[], str | None]
3839

3940
MCP_SESSION_ID = "mcp-session-id"
4041
LAST_EVENT_ID = "last-event-id"
@@ -123,23 +124,21 @@ def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
123124
and message.root.method == "notifications/initialized"
124125
)
125126

126-
def _extract_session_id_from_response(
127+
def _maybe_extract_session_id_from_response(
127128
self,
128129
response: httpx.Response,
129-
is_initialization: bool,
130130
) -> None:
131131
"""Extract and store session ID from response headers."""
132-
if is_initialization:
133-
new_session_id = response.headers.get(MCP_SESSION_ID)
134-
if new_session_id:
135-
self.session_id = new_session_id
136-
logger.info(f"Received session ID: {self.session_id}")
132+
new_session_id = response.headers.get(MCP_SESSION_ID)
133+
if new_session_id:
134+
self.session_id = new_session_id
135+
logger.info(f"Received session ID: {self.session_id}")
137136

138137
async def _handle_sse_event(
139138
self,
140-
sse: Any,
139+
sse: ServerSentEvent,
141140
read_stream_writer: StreamWriter,
142-
original_request_id: Any | None = None,
141+
original_request_id: RequestId | None = None,
143142
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
144143
) -> bool:
145144
"""Handle an SSE event, returning True if the response is complete."""
@@ -161,7 +160,8 @@ async def _handle_sse_event(
161160
if sse.id and resumption_callback:
162161
await resumption_callback(sse.id)
163162

164-
# If this is a response or error, we're done
163+
# If this is a response or error return True indicating completion
164+
# Otherwise, return False to continue listening
165165
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
166166

167167
except Exception as exc:
@@ -262,7 +262,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
262262
return
263263

264264
response.raise_for_status()
265-
self._extract_session_id_from_response(response, is_initialization)
265+
if is_initialization:
266+
self._maybe_extract_session_id_from_response(response)
266267

267268
content_type = response.headers.get(CONTENT_TYPE, "").lower()
268269

@@ -324,7 +325,7 @@ async def _handle_unexpected_content_type(
324325
async def _send_session_terminated_error(
325326
self,
326327
read_stream_writer: StreamWriter,
327-
request_id: Any,
328+
request_id: RequestId,
328329
) -> None:
329330
"""Send a session terminated error response."""
330331
jsonrpc_error = JSONRPCError(
@@ -411,16 +412,26 @@ async def streamablehttp_client(
411412
headers: dict[str, Any] | None = None,
412413
timeout: timedelta = timedelta(seconds=30),
413414
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
414-
):
415+
terminate_on_close: bool = True,
416+
) -> AsyncGenerator[
417+
tuple[
418+
MemoryObjectReceiveStream[SessionMessage | Exception],
419+
MemoryObjectSendStream[SessionMessage],
420+
GetSessionIdCallback,
421+
],
422+
None,
423+
]:
415424
"""
416425
Client transport for StreamableHTTP.
417426
418427
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
419428
event before disconnecting. All other HTTP operations are controlled by `timeout`.
420429
421430
Yields:
422-
Tuple of (read_stream, write_stream, terminate_callback,
423-
get_session_id_callback)
431+
Tuple containing:
432+
- read_stream: Stream for reading messages from the server
433+
- write_stream: Stream for sending messages to the server
434+
- get_session_id_callback: Function to retrieve the current session ID
424435
"""
425436
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
426437

@@ -448,9 +459,6 @@ def start_get_stream() -> None:
448459
transport.handle_get_stream, client, read_stream_writer
449460
)
450461

451-
async def terminate_session() -> None:
452-
await transport.terminate_session(client)
453-
454462
tg.start_soon(
455463
transport.post_writer,
456464
client,
@@ -464,10 +472,11 @@ async def terminate_session() -> None:
464472
yield (
465473
read_stream,
466474
write_stream,
467-
terminate_session,
468475
transport.get_session_id,
469476
)
470477
finally:
478+
if transport.session_id and terminate_on_close:
479+
await transport.terminate_session(client)
471480
tg.cancel_scope.cancel()
472481
finally:
473482
await read_stream_writer.aclose()

src/mcp/server/auth/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""
2+
MCP OAuth server authorization components.
3+
"""

src/mcp/server/auth/errors.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pydantic import ValidationError
2+
3+
4+
def stringify_pydantic_error(validation_error: ValidationError) -> str:
5+
return "\n".join(
6+
f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}"
7+
for e in validation_error.errors()
8+
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""
2+
Request handlers for MCP authorization endpoints.
3+
"""

0 commit comments

Comments
 (0)