Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mcp_proxy_for_aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from importlib.metadata import version as _metadata_version

import mcp_proxy_for_aws.fastmcp_patch as _fastmcp_patch


__all__ = ['__version__']
__version__ = _metadata_version('mcp-proxy-for-aws')
34 changes: 34 additions & 0 deletions mcp_proxy_for_aws/fastmcp_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import fastmcp.server.low_level as low_level_module
import mcp.types
from functools import wraps
from mcp import McpError
from mcp.server.stdio import stdio_server as stdio_server
from mcp.shared.session import RequestResponder


original_receive_request = low_level_module.MiddlewareServerSession._received_request


@wraps(original_receive_request)
async def _received_request(
self,
responder: RequestResponder[mcp.types.ClientRequest, mcp.types.ServerResult],
):
"""Monkey patch fastmcp so that the initialize error from the middleware can be send back to the client.

https://github.com/jlowin/fastmcp/pull/2531
"""
if isinstance(responder.request.root, mcp.types.InitializeRequest):
try:
return await original_receive_request(self, responder)
except McpError as e:
if not responder._completed:
with responder:
return await responder.respond(e.error)

raise e
else:
return await original_receive_request(self, responder)


low_level_module.MiddlewareServerSession._received_request = _received_request
14 changes: 14 additions & 0 deletions mcp_proxy_for_aws/middleware/initialize_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ async def on_initialize(
try:
logger.debug('Received initialize request %s.', context.message)
self._client_factory.set_init_params(context.message)
client = await self._client_factory.get_client()
# connect the http client, fail and don't succeed the stdio connect
# if remote client cannot be connected
client_name = context.message.params.clientInfo.name.lower()
if 'kiro cli' not in client_name and 'q dev cli' not in client_name:
# q cli / kiro cli uses the rust SDK which does not handle json rpc error
# properly during initialization.
# https://github.com/modelcontextprotocol/rust-sdk/pull/569
# if calling _connect below raise mcp error, the q cli will skip the message
# and continue wait for a json rpc response message which will never come.
# Luckily, q cli calls list tool immediately after being connected to a mcp server
# the list_tool call will require the client to be connected again, so the mcp error
# will be displayed in the q cli logs.
await client._connect()
return await call_next(context)
except Exception:
logger.exception('Initialize failed in middleware.')
Expand Down
106 changes: 106 additions & 0 deletions tests/unit/test_fastmcp_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import mcp.types as mt
import pytest
from mcp import McpError
from mcp.shared.session import RequestResponder
from unittest.mock import AsyncMock, Mock, patch


@pytest.mark.asyncio
async def test_patched_received_request_initialize_success():
"""Test that patched _received_request calls original for successful initialize."""
# Import after patching is applied
import fastmcp.server.low_level as low_level_module
from mcp_proxy_for_aws import fastmcp_patch

mock_self = Mock()
mock_self.fastmcp = Mock()

mock_request = Mock()
mock_request.root = Mock(spec=mt.InitializeRequest)

mock_responder = Mock(spec=RequestResponder)
mock_responder.request = mock_request

with patch.object(
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock
) as mock_original:
await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder)
mock_original.assert_called_once_with(mock_self, mock_responder)


@pytest.mark.asyncio
async def test_patched_received_request_initialize_mcp_error_not_completed():
"""Test that patched _received_request handles McpError when responder not completed."""
import fastmcp.server.low_level as low_level_module
from mcp_proxy_for_aws import fastmcp_patch

mock_self = Mock()
mock_self.fastmcp = Mock()

mock_request = Mock()
mock_request.root = Mock(spec=mt.InitializeRequest)

mock_responder = Mock(spec=RequestResponder)
mock_responder.request = mock_request
mock_responder._completed = False
mock_responder.__enter__ = Mock(return_value=mock_responder)
mock_responder.__exit__ = Mock(return_value=False)
mock_responder.respond = AsyncMock()

error = mt.ErrorData(code=1, message='test error')
mcp_error = McpError(error=error)

with patch.object(
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock, side_effect=mcp_error
):
await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder)
mock_responder.respond.assert_called_once_with(error)


@pytest.mark.asyncio
async def test_patched_received_request_initialize_mcp_error_completed():
"""Test that patched _received_request re-raises McpError when responder completed."""
import fastmcp.server.low_level as low_level_module
from mcp_proxy_for_aws import fastmcp_patch

mock_self = Mock()
mock_self.fastmcp = Mock()

mock_request = Mock()
mock_request.root = Mock(spec=mt.InitializeRequest)

mock_responder = Mock(spec=RequestResponder)
mock_responder.request = mock_request
mock_responder._completed = True

error = mt.ErrorData(code=1, message='test error')
mcp_error = McpError(error=error)

with patch.object(
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock, side_effect=mcp_error
):
with pytest.raises(McpError):
await low_level_module.MiddlewareServerSession._received_request(
mock_self, mock_responder
)


@pytest.mark.asyncio
async def test_patched_received_request_non_initialize():
"""Test that patched _received_request calls original for non-initialize requests."""
import fastmcp.server.low_level as low_level_module
from mcp_proxy_for_aws import fastmcp_patch

mock_self = Mock()

mock_request = Mock()
mock_request.root = Mock(spec=mt.CallToolRequest)

mock_responder = Mock(spec=RequestResponder)
mock_responder.request = mock_request

with patch.object(
fastmcp_patch, 'original_receive_request', new_callable=AsyncMock
) as mock_original:
await low_level_module.MiddlewareServerSession._received_request(mock_self, mock_responder)
mock_original.assert_called_once_with(mock_self, mock_responder)
98 changes: 98 additions & 0 deletions tests/unit/test_initialize_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import mcp.types as mt
import pytest
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
from unittest.mock import AsyncMock, Mock


def create_initialize_request(client_name: str) -> mt.InitializeRequest:
"""Create a real InitializeRequest object."""
return mt.InitializeRequest(
method='initialize',
params=mt.InitializeRequestParams(
protocolVersion='2024-11-05',
capabilities=mt.ClientCapabilities(),
clientInfo=mt.Implementation(name=client_name, version='1.0'),
),
)


@pytest.mark.asyncio
async def test_on_initialize_connects_client():
"""Test that on_initialize calls client._connect()."""
mock_client = Mock()
mock_client._connect = AsyncMock()

mock_factory = Mock()
mock_factory.set_init_params = Mock()
mock_factory.get_client = AsyncMock(return_value=mock_client)

middleware = InitializeMiddleware(mock_factory)

mock_context = Mock()
mock_context.message = create_initialize_request('test-client')

mock_call_next = AsyncMock()

await middleware.on_initialize(mock_context, mock_call_next)

mock_factory.set_init_params.assert_called_once_with(mock_context.message)
mock_factory.get_client.assert_called_once()
mock_client._connect.assert_called_once()
mock_call_next.assert_called_once_with(mock_context)


@pytest.mark.asyncio
async def test_on_initialize_fails_if_connect_fails():
"""Test that on_initialize raises exception if _connect() fails."""
mock_client = Mock()
mock_client._connect = AsyncMock(side_effect=Exception('Connection failed'))

mock_factory = Mock()
mock_factory.set_init_params = Mock()
mock_factory.get_client = AsyncMock(return_value=mock_client)

middleware = InitializeMiddleware(mock_factory)

mock_context = Mock()
mock_context.message = create_initialize_request('test-client')

mock_call_next = AsyncMock()

with pytest.raises(Exception, match='Connection failed'):
await middleware.on_initialize(mock_context, mock_call_next)

mock_call_next.assert_not_called()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'client_name',
[
'Kiro CLI',
'kiro cli',
'KIRO CLI',
'Amazon Q Dev CLI',
'amazon q dev cli',
'Q DEV CLI',
],
)
async def test_on_initialize_skips_connect_for_special_clients(client_name):
"""Test that on_initialize skips _connect() for Kiro CLI and Q Dev CLI."""
mock_client = Mock()
mock_client._connect = AsyncMock()

mock_factory = Mock()
mock_factory.set_init_params = Mock()
mock_factory.get_client = AsyncMock(return_value=mock_client)

middleware = InitializeMiddleware(mock_factory)

mock_context = Mock()
mock_context.message = create_initialize_request(client_name)

mock_call_next = AsyncMock()

await middleware.on_initialize(mock_context, mock_call_next)

mock_client._connect.assert_not_called()
mock_call_next.assert_called_once_with(mock_context)
Loading