Skip to content

Commit d1afcb9

Browse files
committed
fix: connect remote mcp client immediately in the initialize middleware
1 parent 919f073 commit d1afcb9

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

mcp_proxy_for_aws/middleware/initialize_middleware.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ async def on_initialize(
2525
try:
2626
logger.debug('Received initialize request %s.', context.message)
2727
self._client_factory.set_init_params(context.message)
28+
client = await self._client_factory.get_client()
29+
# connect the http client, fail and don't succeed the stdio connect
30+
# if remote client cannot be connected
31+
await client._connect()
2832
return await call_next(context)
2933
except Exception:
3034
logger.exception('Initialize failed in middleware.')
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import mcp.types as mt
2+
import pytest
3+
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
4+
from unittest.mock import AsyncMock, Mock
5+
6+
7+
@pytest.mark.asyncio
8+
async def test_on_initialize_connects_client():
9+
"""Test that on_initialize calls client._connect()."""
10+
mock_client = Mock()
11+
mock_client._connect = AsyncMock()
12+
13+
mock_factory = Mock()
14+
mock_factory.set_init_params = Mock()
15+
mock_factory.get_client = AsyncMock(return_value=mock_client)
16+
17+
middleware = InitializeMiddleware(mock_factory)
18+
19+
mock_context = Mock()
20+
mock_context.message = Mock(spec=mt.InitializeRequest)
21+
22+
mock_call_next = AsyncMock()
23+
24+
await middleware.on_initialize(mock_context, mock_call_next)
25+
26+
mock_factory.set_init_params.assert_called_once_with(mock_context.message)
27+
mock_factory.get_client.assert_called_once()
28+
mock_client._connect.assert_called_once()
29+
mock_call_next.assert_called_once_with(mock_context)
30+
31+
32+
@pytest.mark.asyncio
33+
async def test_on_initialize_fails_if_connect_fails():
34+
"""Test that on_initialize raises exception if _connect() fails."""
35+
mock_client = Mock()
36+
mock_client._connect = AsyncMock(side_effect=Exception('Connection failed'))
37+
38+
mock_factory = Mock()
39+
mock_factory.set_init_params = Mock()
40+
mock_factory.get_client = AsyncMock(return_value=mock_client)
41+
42+
middleware = InitializeMiddleware(mock_factory)
43+
44+
mock_context = Mock()
45+
mock_context.message = Mock(spec=mt.InitializeRequest)
46+
47+
mock_call_next = AsyncMock()
48+
49+
with pytest.raises(Exception, match='Connection failed'):
50+
await middleware.on_initialize(mock_context, mock_call_next)
51+
52+
mock_call_next.assert_not_called()

0 commit comments

Comments
 (0)