Skip to content

Commit d4ccac1

Browse files
committed
fix: do not call initialize for q dev cli / kiro cli
1 parent 3bbe8ec commit d4ccac1

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

mcp_proxy_for_aws/middleware/initialize_middleware.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,17 @@ async def on_initialize(
2828
client = await self._client_factory.get_client()
2929
# connect the http client, fail and don't succeed the stdio connect
3030
# if remote client cannot be connected
31-
await client._connect()
31+
client_name = context.message.params.clientInfo.name.lower()
32+
if 'kiro cli' not in client_name and 'q dev cli' not in client_name:
33+
# q cli / kiro cli uses the rust SDK which does not handle json rpc error
34+
# properly during initialization.
35+
# https://github.com/modelcontextprotocol/rust-sdk/pull/569
36+
# if calling _connect below raise mcp error, the q cli will skip the message
37+
# and continue wait for a json rpc response message which will never come.
38+
# Luckily, q cli calls list tool immediately after being connected to a mcp server
39+
# the list_tool call will require the client to be connected again, so the mcp error
40+
# will be displayed in the q cli logs.
41+
await client._connect()
3242
return await call_next(context)
3343
except Exception:
3444
logger.exception('Initialize failed in middleware.')

tests/unit/test_initialize_middleware.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@
44
from unittest.mock import AsyncMock, Mock
55

66

7+
def create_initialize_request(client_name: str) -> mt.InitializeRequest:
8+
"""Create a real InitializeRequest object."""
9+
return mt.InitializeRequest(
10+
method='initialize',
11+
params=mt.InitializeRequestParams(
12+
protocolVersion='2024-11-05',
13+
capabilities=mt.ClientCapabilities(),
14+
clientInfo=mt.Implementation(name=client_name, version='1.0'),
15+
),
16+
)
17+
18+
719
@pytest.mark.asyncio
820
async def test_on_initialize_connects_client():
921
"""Test that on_initialize calls client._connect()."""
@@ -17,7 +29,7 @@ async def test_on_initialize_connects_client():
1729
middleware = InitializeMiddleware(mock_factory)
1830

1931
mock_context = Mock()
20-
mock_context.message = Mock(spec=mt.InitializeRequest)
32+
mock_context.message = create_initialize_request('test-client')
2133

2234
mock_call_next = AsyncMock()
2335

@@ -42,11 +54,45 @@ async def test_on_initialize_fails_if_connect_fails():
4254
middleware = InitializeMiddleware(mock_factory)
4355

4456
mock_context = Mock()
45-
mock_context.message = Mock(spec=mt.InitializeRequest)
57+
mock_context.message = create_initialize_request('test-client')
4658

4759
mock_call_next = AsyncMock()
4860

4961
with pytest.raises(Exception, match='Connection failed'):
5062
await middleware.on_initialize(mock_context, mock_call_next)
5163

5264
mock_call_next.assert_not_called()
65+
66+
67+
@pytest.mark.asyncio
68+
@pytest.mark.parametrize(
69+
'client_name',
70+
[
71+
'Kiro CLI',
72+
'kiro cli',
73+
'KIRO CLI',
74+
'Amazon Q Dev CLI',
75+
'amazon q dev cli',
76+
'Q DEV CLI',
77+
],
78+
)
79+
async def test_on_initialize_skips_connect_for_special_clients(client_name):
80+
"""Test that on_initialize skips _connect() for Kiro CLI and Q Dev CLI."""
81+
mock_client = Mock()
82+
mock_client._connect = AsyncMock()
83+
84+
mock_factory = Mock()
85+
mock_factory.set_init_params = Mock()
86+
mock_factory.get_client = AsyncMock(return_value=mock_client)
87+
88+
middleware = InitializeMiddleware(mock_factory)
89+
90+
mock_context = Mock()
91+
mock_context.message = create_initialize_request(client_name)
92+
93+
mock_call_next = AsyncMock()
94+
95+
await middleware.on_initialize(mock_context, mock_call_next)
96+
97+
mock_client._connect.assert_not_called()
98+
mock_call_next.assert_called_once_with(mock_context)

0 commit comments

Comments
 (0)