Skip to content

Commit 76ae778

Browse files
committed
moved core logic for progress call back to BaseSession
1 parent 80913e7 commit 76ae778

File tree

2 files changed

+72
-49
lines changed

2 files changed

+72
-49
lines changed

src/mcp/client/session.py

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import mcp.types as types
99
from mcp.shared.context import RequestContext
1010
from mcp.shared.message import SessionMessage
11-
from mcp.shared.session import BaseSession, RequestResponder
11+
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
1212
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1313

1414
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
@@ -35,13 +35,6 @@ async def __call__(
3535
) -> None: ...
3636

3737

38-
class ProgressFnT(Protocol):
39-
async def __call__(
40-
self,
41-
params: types.ProgressNotificationParams,
42-
) -> None: ...
43-
44-
4538
class MessageHandlerFnT(Protocol):
4639
async def __call__(
4740
self,
@@ -98,9 +91,6 @@ class ClientSession(
9891
types.ServerNotification,
9992
]
10093
):
101-
_progress_id: int
102-
_in_progress: dict[types.ProgressToken, ProgressFnT]
103-
10494
def __init__(
10595
self,
10696
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -124,8 +114,6 @@ def __init__(
124114
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
125115
self._logging_callback = logging_callback or _default_logging_callback
126116
self._message_handler = message_handler or _default_message_handler
127-
self._progress_id = 0
128-
self._in_progress = {}
129117

130118
async def initialize(self) -> types.InitializeResult:
131119
sampling = types.SamplingCapability()
@@ -274,34 +262,17 @@ async def call_tool(
274262
progress_callback: ProgressFnT | None = None,
275263
) -> types.CallToolResult:
276264
"""Send a tools/call request."""
277-
278-
if progress_callback is None:
279-
progress_id = None
280-
call_params = types.CallToolRequestParams(name=name, arguments=arguments)
281-
else:
282-
progress_id = self._progress_id
283-
self._progress_id = progress_id + 1
284-
285-
call_meta = types.RequestParams.Meta(progressToken=progress_id)
286-
call_params = types.CallToolRequestParams(
287-
name=name, arguments=arguments, _meta=call_meta
288-
)
289-
self._in_progress[progress_id] = progress_callback
290-
291-
try:
292-
return await self.send_request(
293-
types.ClientRequest(
294-
types.CallToolRequest(
295-
method="tools/call",
296-
params=call_params,
297-
)
298-
),
299-
types.CallToolResult,
300-
request_read_timeout_seconds=read_timeout_seconds,
301-
)
302-
finally:
303-
if progress_id is not None:
304-
self._in_progress.pop(progress_id, None)
265+
return await self.send_request(
266+
types.ClientRequest(
267+
types.CallToolRequest(
268+
method="tools/call",
269+
params=types.CallToolRequestParams(name=name, arguments=arguments),
270+
)
271+
),
272+
types.CallToolResult,
273+
request_read_timeout_seconds=read_timeout_seconds,
274+
progress_callback=progress_callback,
275+
)
305276

306277
async def list_prompts(self) -> types.ListPromptsResult:
307278
"""Send a prompts/list request."""
@@ -414,9 +385,5 @@ async def _received_notification(
414385
match notification.root:
415386
case types.LoggingMessageNotification(params=params):
416387
await self._logging_callback(params)
417-
case types.ProgressNotification(params=params):
418-
if params.progressToken in self._in_progress:
419-
progress_callback = self._in_progress[params.progressToken]
420-
await progress_callback(params)
421388
case _:
422389
pass

src/mcp/shared/session.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from contextlib import AsyncExitStack
44
from datetime import timedelta
55
from types import TracebackType
6-
from typing import Any, Generic, TypeVar
6+
from typing import Any, Generic, Protocol, TypeVar
77

88
import anyio
99
import httpx
@@ -24,6 +24,9 @@
2424
JSONRPCNotification,
2525
JSONRPCRequest,
2626
JSONRPCResponse,
27+
ProgressNotification,
28+
ProgressNotificationParams,
29+
ProgressToken,
2730
RequestParams,
2831
ServerNotification,
2932
ServerRequest,
@@ -39,6 +42,14 @@
3942
"ReceiveNotificationT", ClientNotification, ServerNotification
4043
)
4144

45+
46+
class ProgressFnT(Protocol):
47+
async def __call__(
48+
self,
49+
params: ProgressNotificationParams,
50+
) -> None: ...
51+
52+
4253
RequestId = str | int
4354

4455

@@ -168,7 +179,9 @@ class BaseSession(
168179
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
169180
]
170181
_request_id: int
182+
_progress_id: int
171183
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
184+
_in_progress: dict[ProgressToken, ProgressFnT]
172185

173186
def __init__(
174187
self,
@@ -187,6 +200,8 @@ def __init__(
187200
self._receive_notification_type = receive_notification_type
188201
self._session_read_timeout_seconds = read_timeout_seconds
189202
self._in_flight = {}
203+
self._progress_id = 0
204+
self._in_progress = {}
190205
self._exit_stack = AsyncExitStack()
191206

192207
async def __aenter__(self) -> Self:
@@ -214,19 +229,44 @@ async def send_request(
214229
result_type: type[ReceiveResultT],
215230
request_read_timeout_seconds: timedelta | None = None,
216231
metadata: MessageMetadata = None,
232+
progress_callback: ProgressFnT | None = None,
217233
) -> ReceiveResultT:
218234
"""
219235
Sends a request and wait for a response. Raises an McpError if the
220236
response contains an error. If a request read timeout is provided, it
221237
will take precedence over the session read timeout.
222238
239+
If progress_callback is provided any progress notifications sent from the
240+
receiver will be passed back to the sender
241+
223242
Do not use this method to emit notifications! Use send_notification()
224243
instead.
225244
"""
226245

227246
request_id = self._request_id
228247
self._request_id = request_id + 1
229248

249+
progress_id = None
250+
send_request = None
251+
252+
if progress_callback is not None:
253+
if request.root.params is not None:
254+
progress_id = self._progress_id
255+
self._progress_id = progress_id + 1
256+
new_params = request.root.params.model_copy(
257+
update={"meta": RequestParams.Meta(progressToken=progress_id)}
258+
)
259+
new_root = request.root.model_copy(update={"params": new_params})
260+
send_request = request.model_copy(update={"root": new_root})
261+
self._in_progress[progress_id] = progress_callback
262+
else:
263+
raise ValueError(
264+
f"{type(request.root).__name__} does not support progress"
265+
)
266+
267+
if send_request is None:
268+
send_request = request
269+
230270
response_stream, response_stream_reader = anyio.create_memory_object_stream[
231271
JSONRPCResponse | JSONRPCError
232272
](1)
@@ -236,11 +276,11 @@ async def send_request(
236276
jsonrpc_request = JSONRPCRequest(
237277
jsonrpc="2.0",
238278
id=request_id,
239-
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
279+
**send_request.model_dump(
280+
by_alias=True, mode="json", exclude_none=True
281+
),
240282
)
241283

242-
# TODO: Support progress callbacks
243-
244284
await self._write_stream.send(
245285
SessionMessage(
246286
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
@@ -276,6 +316,8 @@ async def send_request(
276316

277317
finally:
278318
self._response_streams.pop(request_id, None)
319+
if progress_id is not None:
320+
self._in_progress.pop(progress_id, None)
279321
await response_stream.aclose()
280322
await response_stream_reader.aclose()
281323

@@ -364,6 +406,20 @@ async def _receive_loop(self) -> None:
364406
if cancelled_id in self._in_flight:
365407
await self._in_flight[cancelled_id].cancel()
366408
else:
409+
match notification.root:
410+
case ProgressNotification(params=params):
411+
if params.progressToken in self._in_progress:
412+
progress_callback = self._in_progress[
413+
params.progressToken
414+
]
415+
await progress_callback(params)
416+
else:
417+
logging.warning(
418+
"Unknown progress token %s",
419+
params.progressToken,
420+
)
421+
case _:
422+
pass
367423
await self._received_notification(notification)
368424
await self._handle_incoming(notification)
369425
except Exception as e:

0 commit comments

Comments
 (0)