Skip to content

Commit c3ef7de

Browse files
committed
tinkering with resource progress
1 parent f2f4dbd commit c3ef7de

File tree

6 files changed

+65
-22
lines changed

6 files changed

+65
-22
lines changed

src/mcp/client/session.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from datetime import timedelta
2-
from typing import Any, Protocol
2+
from typing import Annotated, Any, Protocol
33

44
import anyio.lowlevel
55
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
6-
from pydantic import AnyUrl, TypeAdapter
6+
from pydantic import TypeAdapter
7+
from pydantic.networks import AnyUrl, UrlConstraints
78

89
import mcp.types as types
910
from mcp.shared.context import RequestContext
1011
from mcp.shared.message import SessionMessage
11-
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
12+
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, ResourceProgressFnT
1213
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1314

1415
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
@@ -173,6 +174,7 @@ async def send_progress_notification(
173174
progress: float,
174175
total: float | None = None,
175176
message: str | None = None,
177+
# TODO decide whether clients can send resource progress too?
176178
) -> None:
177179
"""Send a progress notification."""
178180
await self.send_notification(
@@ -203,6 +205,7 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul
203205

204206
async def list_resources(
205207
self, cursor: str | None = None
208+
# TODO suggest in progress resources should be excluded by default? possibly add an optional flag to include?
206209
) -> types.ListResourcesResult:
207210
"""Send a resources/list request."""
208211
return await self.send_request(
@@ -233,7 +236,7 @@ async def list_resource_templates(
233236
types.ListResourceTemplatesResult,
234237
)
235238

236-
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
239+
async def read_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.ReadResourceResult:
237240
"""Send a resources/read request."""
238241
return await self.send_request(
239242
types.ClientRequest(
@@ -245,7 +248,7 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
245248
types.ReadResourceResult,
246249
)
247250

248-
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
251+
async def subscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult:
249252
"""Send a resources/subscribe request."""
250253
return await self.send_request(
251254
types.ClientRequest(
@@ -257,7 +260,7 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
257260
types.EmptyResult,
258261
)
259262

260-
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
263+
async def unsubscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult:
261264
"""Send a resources/unsubscribe request."""
262265
return await self.send_request(
263266
types.ClientRequest(
@@ -274,7 +277,7 @@ async def call_tool(
274277
name: str,
275278
arguments: dict[str, Any] | None = None,
276279
read_timeout_seconds: timedelta | None = None,
277-
progress_callback: ProgressFnT | None = None,
280+
progress_callback: ProgressFnT | ResourceProgressFnT | None = None,
278281
) -> types.CallToolResult:
279282
"""Send a tools/call request with optional progress callback support."""
280283

src/mcp/server/fastmcp/server.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
asynccontextmanager,
1111
)
1212
from itertools import chain
13-
from typing import Any, Generic, Literal
13+
from typing import Annotated, Any, Generic, Literal
1414

1515
import anyio
1616
import pydantic_core
1717
from pydantic import BaseModel, Field
18-
from pydantic.networks import AnyUrl
18+
from pydantic.networks import AnyUrl, UrlConstraints
1919
from pydantic_settings import BaseSettings, SettingsConfigDict
2020
from starlette.applications import Starlette
2121
from starlette.middleware import Middleware
@@ -956,7 +956,11 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
956956
return self._request_context
957957

958958
async def report_progress(
959-
self, progress: float, total: float | None = None, message: str | None = None
959+
self,
960+
progress: float,
961+
total: float | None = None,
962+
message: str | None = None,
963+
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None,
960964
) -> None:
961965
"""Report progress for the current operation.
962966
@@ -979,6 +983,7 @@ async def report_progress(
979983
progress=progress,
980984
total=total,
981985
message=message,
986+
resource_uri=resource_uri,
982987
)
983988

984989
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:

src/mcp/server/session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
3838
"""
3939

4040
from enum import Enum
41-
from typing import Any, TypeVar
41+
from typing import Annotated, Any, TypeVar
4242

4343
import anyio
4444
import anyio.lowlevel
4545
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
46-
from pydantic import AnyUrl
46+
from pydantic.networks import AnyUrl, UrlConstraints
4747

4848
import mcp.types as types
4949
from mcp.server.models import InitializationOptions
@@ -288,6 +288,7 @@ async def send_progress_notification(
288288
total: float | None = None,
289289
message: str | None = None,
290290
related_request_id: str | None = None,
291+
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None,
291292
) -> None:
292293
"""Send a progress notification."""
293294
await self.send_notification(
@@ -299,6 +300,7 @@ async def send_progress_notification(
299300
progress=progress,
300301
total=total,
301302
message=message,
303+
resource_uri=resource_uri,
302304
),
303305
)
304306
),

src/mcp/shared/session.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from contextlib import AsyncExitStack
44
from datetime import timedelta
55
from types import TracebackType
6-
from typing import Any, Generic, Protocol, TypeVar
6+
from typing import Annotated, Any, Generic, Protocol, TypeVar, runtime_checkable
77

88
import anyio
99
import httpx
10+
import inspect
1011
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1112
from pydantic import BaseModel
13+
from pydantic.networks import AnyUrl, UrlConstraints
1214
from typing_extensions import Self
1315

1416
from mcp.shared.exceptions import McpError
@@ -43,13 +45,21 @@
4345
RequestId = str | int
4446

4547

48+
@runtime_checkable
4649
class ProgressFnT(Protocol):
4750
"""Protocol for progress notification callbacks."""
4851

4952
async def __call__(
5053
self, progress: float, total: float | None, message: str | None
5154
) -> None: ...
5255

56+
@runtime_checkable
57+
class ResourceProgressFnT(Protocol):
58+
"""Protocol for progress notification callbacks with resources."""
59+
60+
async def __call__(
61+
self, progress: float, total: float | None, message: str | None, resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None
62+
) -> None: ...
5363

5464
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
5565
"""Handles responding to MCP requests and manages request lifecycle.
@@ -178,7 +188,8 @@ class BaseSession(
178188
]
179189
_request_id: int
180190
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
181-
_progress_callbacks: dict[RequestId, ProgressFnT]
191+
_progress_callbacks: dict[RequestId, ProgressFnT ]
192+
_resource_progress_callbacks: dict[RequestId, ResourceProgressFnT]
182193

183194
def __init__(
184195
self,
@@ -198,6 +209,7 @@ def __init__(
198209
self._session_read_timeout_seconds = read_timeout_seconds
199210
self._in_flight = {}
200211
self._progress_callbacks = {}
212+
self._resource_progress_callbacks = {}
201213
self._exit_stack = AsyncExitStack()
202214

203215
async def __aenter__(self) -> Self:
@@ -225,7 +237,7 @@ async def send_request(
225237
result_type: type[ReceiveResultT],
226238
request_read_timeout_seconds: timedelta | None = None,
227239
metadata: MessageMetadata = None,
228-
progress_callback: ProgressFnT | None = None,
240+
progress_callback: ProgressFnT | ResourceProgressFnT | None = None,
229241
) -> ReceiveResultT:
230242
"""
231243
Sends a request and wait for a response. Raises an McpError if the
@@ -252,8 +264,14 @@ async def send_request(
252264
if "_meta" not in request_data["params"]:
253265
request_data["params"]["_meta"] = {}
254266
request_data["params"]["_meta"]["progressToken"] = request_id
255-
# Store the callback for this request
256-
self._progress_callbacks[request_id] = progress_callback
267+
# note this is required to ensure backwards compatibility for previous clients
268+
signature = inspect.signature(progress_callback.__call__)
269+
if 'resource_uri' in signature.parameters:
270+
# Store the callback for this request
271+
self._resource_progress_callbacks[request_id] = progress_callback # type: ignore
272+
else:
273+
# Store the callback for this request
274+
self._progress_callbacks[request_id] = progress_callback
257275

258276
try:
259277
jsonrpc_request = JSONRPCRequest(
@@ -397,6 +415,15 @@ async def _receive_loop(self) -> None:
397415
notification.root.params.total,
398416
notification.root.params.message,
399417
)
418+
elif progress_token in self._resource_progress_callbacks:
419+
callback = self._resource_progress_callbacks[progress_token]
420+
await callback(
421+
notification.root.params.progress,
422+
notification.root.params.total,
423+
notification.root.params.message,
424+
notification.root.params.resource_uri,
425+
)
426+
400427
await self._received_notification(notification)
401428
await self._handle_incoming(notification)
402429
except Exception as e:

src/mcp/types.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,12 +346,18 @@ class ProgressNotificationParams(NotificationParams):
346346
total is unknown.
347347
"""
348348
total: float | None = None
349+
"""Total number of items to process (or total progress required), if known."""
350+
message: str | None = None
349351
"""
350352
Message related to progress. This should provide relevant human readable
351353
progress information.
352354
"""
353-
message: str | None = None
354-
"""Total number of items to process (or total progress required), if known."""
355+
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None
356+
"""
357+
An optional reference to an ephemeral resource associated with this progress, servers
358+
may delete these at their descretion, but are encouraged to make them available for
359+
a reasonable time period to allow clients to retrieve and cache the resources locally
360+
"""
355361
model_config = ConfigDict(extra="allow")
356362

357363

tests/issues/test_176_progress_token.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call():
3939
mock_session.send_progress_notification.call_count == 3
4040
), "All progress notifications should be sent"
4141
mock_session.send_progress_notification.assert_any_call(
42-
progress_token=0, progress=0.0, total=10.0, message=None
42+
progress_token=0, progress=0.0, total=10.0, message=None, resource_uri=None
4343
)
4444
mock_session.send_progress_notification.assert_any_call(
45-
progress_token=0, progress=5.0, total=10.0, message=None
45+
progress_token=0, progress=5.0, total=10.0, message=None, resource_uri=None
4646
)
4747
mock_session.send_progress_notification.assert_any_call(
48-
progress_token=0, progress=10.0, total=10.0, message=None
48+
progress_token=0, progress=10.0, total=10.0, message=None, resource_uri=None
4949
)

0 commit comments

Comments
 (0)