Skip to content

Commit 79f3c4e

Browse files
committed
avoid exceptions during join call tool on timeout as this is expected behaviour use None when no result retrieved instead
1 parent e4c25b7 commit 79f3c4e

File tree

5 files changed

+44
-35
lines changed

5 files changed

+44
-35
lines changed

src/mcp/client/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ async def join_call_tool(
343343
progress_callback: ProgressFnT | None = None,
344344
request_read_timeout_seconds: timedelta | None = None,
345345
done_on_timeout: bool = True,
346-
) -> types.CallToolResult:
346+
) -> types.CallToolResult | None:
347347
return await self.join_request(
348348
request_id,
349349
types.CallToolResult,

src/mcp/server/streamable_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,8 @@ async def send_event(event_message: EventMessage) -> None:
763763
async with msg_reader:
764764
async for event_message in msg_reader:
765765
event_data = self._create_event_data(event_message)
766-
767766
await sse_stream_writer.send(event_data)
767+
768768
except Exception as e:
769769
logger.exception(f"Error in replay sender: {e}")
770770

src/mcp/shared/session.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def receive_response(
186186
self,
187187
request_id: RequestId,
188188
timeout: float | None = None,
189-
) -> JSONRPCResponse | JSONRPCError: ...
189+
) -> JSONRPCResponse | JSONRPCError | None: ...
190190

191191
async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: ...
192192

@@ -278,7 +278,7 @@ async def receive_response(
278278
self,
279279
request_id: RequestId,
280280
timeout: float | None = None,
281-
) -> JSONRPCResponse | JSONRPCError:
281+
) -> JSONRPCResponse | JSONRPCError | None:
282282
_, receive_stream = self._response_streams.get(request_id, [None, None])
283283
if receive_stream is None:
284284
raise McpError(
@@ -302,16 +302,7 @@ async def receive_response(
302302
)
303303
)
304304
except TimeoutError:
305-
raise McpError(
306-
ErrorData(
307-
code=httpx.codes.REQUEST_TIMEOUT,
308-
message=(
309-
f"Timed out while waiting for response to "
310-
f"{request.__class__.__name__}. Waited "
311-
f"{timeout} seconds."
312-
),
313-
)
314-
)
305+
return None
315306

316307
async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool:
317308
send_stream, _ = self._response_streams.get(message.id, [None, None])
@@ -453,9 +444,11 @@ async def join_request(
453444
request_read_timeout_seconds: timedelta | None = None,
454445
progress_callback: ProgressFnT | None = None,
455446
done_on_timeout: bool = True,
456-
) -> ReceiveResultT:
447+
) -> ReceiveResultT | None:
457448
"""
458-
Joins a request previously started via start_request
449+
Joins a request previously started via start_request.
450+
451+
Returns the result or None if timeout is reached.
459452
"""
460453
resume = self._request_state_manager.resume(request_id)
461454

@@ -488,17 +481,23 @@ async def join_request(
488481

489482
response_or_error = await self._request_state_manager.receive_response(request_id, timeout)
490483

491-
if isinstance(response_or_error, JSONRPCError):
484+
if response_or_error is None:
485+
if done_on_timeout:
486+
await self._request_state_manager.close_request(request_id)
487+
return None
488+
elif isinstance(response_or_error, JSONRPCError):
492489
if response_or_error.error.code == httpx.codes.REQUEST_TIMEOUT.value:
493490
if done_on_timeout:
494491
await self._request_state_manager.close_request(request_id)
492+
return None
495493
else:
496494
await self._request_state_manager.close_request(request_id)
497-
raise McpError(response_or_error.error)
498-
else:
495+
raise McpError(response_or_error.error)
496+
else :
499497
await self._request_state_manager.close_request(request_id)
500498
return result_type.model_validate(response_or_error.result)
501499

500+
502501
async def cancel_request(self, request_id: RequestId) -> bool:
503502
"""
504503
Cancels a request previously started via start_request
@@ -533,7 +532,20 @@ async def send_request(
533532
"""
534533
request_id = await self.start_request(request, metadata, progress_callback)
535534
try:
536-
return await self.join_request(request_id, result_type, request_read_timeout_seconds)
535+
result = await self.join_request(request_id, result_type, request_read_timeout_seconds)
536+
if result is None:
537+
raise McpError(
538+
ErrorData(
539+
code=httpx.codes.REQUEST_TIMEOUT,
540+
message=(
541+
f"Timed out while waiting for response to "
542+
f"{request.__class__.__name__}. Waited "
543+
f"{request_read_timeout_seconds} seconds."
544+
),
545+
)
546+
)
547+
else:
548+
return result
537549
finally:
538550
await self._request_state_manager.close_request(request_id)
539551

tests/client/test_session.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -635,14 +635,10 @@ async def message_handler(
635635
request_id = await session.request_call_tool("hello", {"name": "world"})
636636

637637
with anyio.fail_after(3):
638-
try:
639-
result = await session.join_call_tool(
640-
request_id, request_read_timeout_seconds=timedelta(seconds=0.5), done_on_timeout=False
641-
)
642-
# raise RuntimeError("Expected fail")
643-
except McpError as e:
644-
if not e.error.code == httpx.codes.REQUEST_TIMEOUT:
645-
raise e
638+
result = await session.join_call_tool(
639+
request_id, request_read_timeout_seconds=timedelta(seconds=0.5), done_on_timeout=False
640+
)
641+
assert result is None
646642
send_result.set()
647643
result = await session.join_call_tool(
648644
request_id, request_read_timeout_seconds=timedelta(seconds=1), done_on_timeout=False

tests/shared/test_streamable_http.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,13 +1351,13 @@ async def run_tool():
13511351
captured_request_id = await session.request_call_tool(
13521352
"long_running_with_checkpoints", arguments={}
13531353
)
1354-
try:
1355-
await session.join_call_tool(
1356-
captured_request_id, request_read_timeout_seconds=timedelta(seconds=0.01)
1357-
)
1358-
raise RuntimeError("Expected timeout")
1359-
except McpError as e:
1360-
assert e.error.code == httpx.codes.REQUEST_TIMEOUT.value
1354+
1355+
result = await session.join_call_tool(
1356+
captured_request_id, request_read_timeout_seconds=timedelta(seconds=0.01),
1357+
done_on_timeout=False
1358+
)
1359+
1360+
assert result is None
13611361

13621362
timed_out.set()
13631363

@@ -1409,6 +1409,7 @@ async def run_tool():
14091409
assert captured_request_id is not None
14101410

14111411
result = await session.join_call_tool(captured_request_id)
1412+
assert result is not None
14121413

14131414
# We should get a complete result
14141415
assert len(result.content) == 1

0 commit comments

Comments
 (0)