Skip to content

Commit 3fbd94a

Browse files
committed
fix: patch fastmcp lowlevel session method
We need this until jlowin/fastmcp#2531 is accepted.
1 parent d1afcb9 commit 3fbd94a

File tree

3 files changed

+142
-0
lines changed

3 files changed

+142
-0
lines changed

mcp_proxy_for_aws/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from importlib.metadata import version as _metadata_version
1818

19+
import mcp_proxy_for_aws.fastmcp_patch as _fastmcp_patch
20+
1921

2022
__all__ = ['__version__']
2123
__version__ = _metadata_version('mcp-proxy-for-aws')

mcp_proxy_for_aws/fastmcp_patch.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import fastmcp.server.low_level as low_level_module
2+
import mcp.types
3+
from functools import wraps
4+
from mcp import McpError
5+
from mcp.server.stdio import stdio_server as stdio_server
6+
from mcp.shared.session import RequestResponder
7+
8+
9+
original_receive_request = low_level_module.MiddlewareServerSession._received_request
10+
11+
12+
@wraps(original_receive_request)
13+
async def _received_request(
14+
self,
15+
responder: RequestResponder[mcp.types.ClientRequest, mcp.types.ServerResult],
16+
):
17+
"""Monkey patch fastmcp so that the initialize error from the middleware can be send back to the client.
18+
19+
https://github.com/jlowin/fastmcp/pull/2531
20+
"""
21+
if isinstance(responder.request.root, mcp.types.InitializeRequest):
22+
try:
23+
return await original_receive_request(self, responder)
24+
except McpError as e:
25+
if not responder._completed:
26+
with responder:
27+
return await responder.respond(e.error)
28+
29+
raise e
30+
else:
31+
return await original_receive_request(self, responder)
32+
33+
34+
low_level_module.MiddlewareServerSession._received_request = _received_request

tests/unit/test_fastmcp_patch.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import mcp.types as mt
2+
import pytest
3+
from mcp import McpError
4+
from mcp.shared.session import RequestResponder
5+
from unittest.mock import AsyncMock, Mock, patch
6+
7+
8+
@pytest.mark.asyncio
9+
async def test_patched_received_request_initialize_success():
10+
"""Test that patched _received_request calls original for successful initialize."""
11+
# Import after patching is applied
12+
import fastmcp.server.low_level as low_level_module
13+
from mcp_proxy_for_aws import fastmcp_patch
14+
15+
mock_self = Mock()
16+
mock_self.fastmcp = Mock()
17+
18+
mock_request = Mock()
19+
mock_request.root = Mock(spec=mt.InitializeRequest)
20+
21+
mock_responder = Mock(spec=RequestResponder)
22+
mock_responder.request = mock_request
23+
24+
with patch.object(
25+
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock
26+
) as mock_original:
27+
await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder)
28+
mock_original.assert_called_once_with(mock_self, mock_responder)
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_patched_received_request_initialize_mcp_error_not_completed():
33+
"""Test that patched _received_request handles McpError when responder not completed."""
34+
import fastmcp.server.low_level as low_level_module
35+
from mcp_proxy_for_aws import fastmcp_patch
36+
37+
mock_self = Mock()
38+
mock_self.fastmcp = Mock()
39+
40+
mock_request = Mock()
41+
mock_request.root = Mock(spec=mt.InitializeRequest)
42+
43+
mock_responder = Mock(spec=RequestResponder)
44+
mock_responder.request = mock_request
45+
mock_responder._completed = False
46+
mock_responder.__enter__ = Mock(return_value=mock_responder)
47+
mock_responder.__exit__ = Mock(return_value=False)
48+
mock_responder.respond = AsyncMock()
49+
50+
error = mt.ErrorData(code=1, message='test error')
51+
mcp_error = McpError(error=error)
52+
53+
with patch.object(
54+
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock, side_effect=mcp_error
55+
):
56+
await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder)
57+
mock_responder.respond.assert_called_once_with(error)
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_patched_received_request_initialize_mcp_error_completed():
62+
"""Test that patched _received_request re-raises McpError when responder completed."""
63+
import fastmcp.server.low_level as low_level_module
64+
from mcp_proxy_for_aws import fastmcp_patch
65+
66+
mock_self = Mock()
67+
mock_self.fastmcp = Mock()
68+
69+
mock_request = Mock()
70+
mock_request.root = Mock(spec=mt.InitializeRequest)
71+
72+
mock_responder = Mock(spec=RequestResponder)
73+
mock_responder.request = mock_request
74+
mock_responder._completed = True
75+
76+
error = mt.ErrorData(code=1, message='test error')
77+
mcp_error = McpError(error=error)
78+
79+
with patch.object(
80+
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock, side_effect=mcp_error
81+
):
82+
with pytest.raises(McpError):
83+
await low_level_module.MiddlewareServerSession._received_request(
84+
mock_self, mock_responder
85+
)
86+
87+
88+
@pytest.mark.asyncio
89+
async def test_patched_received_request_non_initialize():
90+
"""Test that patched _received_request calls original for non-initialize requests."""
91+
import fastmcp.server.low_level as low_level_module
92+
from mcp_proxy_for_aws import fastmcp_patch
93+
94+
mock_self = Mock()
95+
96+
mock_request = Mock()
97+
mock_request.root = Mock(spec=mt.CallToolRequest)
98+
99+
mock_responder = Mock(spec=RequestResponder)
100+
mock_responder.request = mock_request
101+
102+
with patch.object(
103+
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock
104+
) as mock_original:
105+
await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder)
106+
mock_original.assert_called_once_with(mock_self, mock_responder)

0 commit comments

Comments
 (0)