Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions mcp_proxy_for_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.

import boto3
import httpx
import logging
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from botocore.credentials import Credentials
from contextlib import _AsyncGeneratorContextManager
from datetime import timedelta
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
from mcp.client.streamable_http import GetSessionIdCallback, streamable_http_client
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth
Expand Down Expand Up @@ -113,13 +114,18 @@ def aws_iam_streamablehttp_client(
# Create a SigV4 authentication handler with AWS credentials
auth = SigV4HTTPXAuth(creds, aws_service, region)

# Convert timeout to httpx.Timeout if needed
if isinstance(timeout, (int, float)):
httpx_timeout = httpx.Timeout(timeout=timeout)
elif isinstance(timeout, timedelta):
httpx_timeout = httpx.Timeout(timeout=timeout.total_seconds())
else:
httpx_timeout = timeout

# Create the HTTP client with authentication and configuration
http_client = httpx_client_factory(headers=headers, timeout=httpx_timeout, auth=auth)

# Return the streamable HTTP client context manager with AWS IAM authentication
return streamablehttp_client(
url=endpoint,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=terminate_on_close,
httpx_client_factory=httpx_client_factory,
auth=auth,
return streamable_http_client(
url=endpoint, http_client=http_client, terminate_on_close=terminate_on_close
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ description = "MCP Proxy for AWS"
readme = "README.md"
requires-python = ">=3.10,<3.15"
dependencies = [
"fastmcp (>=2.13.1,<2.14.1)",
"fastmcp>=2.14.4,<3.0.0",
"boto3>=1.41.0",
"botocore[crt]>=1.41.0",
]
Expand Down
59 changes: 38 additions & 21 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_boto3_session_parameters(
mock_read, mock_write, mock_get_session = mock_streams

with patch('boto3.Session', return_value=mock_session) as mock_boto:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand Down Expand Up @@ -94,9 +94,11 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_auth = Mock()
mock_auth_cls.return_value = mock_auth
mock_http_client = Mock()
mock_factory = Mock(return_value=mock_http_client)
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -106,17 +108,20 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic
endpoint='https://test.example.com/mcp',
aws_service=service_name,
aws_region=region,
httpx_client_factory=mock_factory,
):
pass

mock_auth_cls.assert_called_once_with(
# Auth should be constructed with the resolved credentials, service, and region,
# and passed into the streamable client.
# Auth should be constructed with the resolved credentials, service, and region
mock_session.get_credentials.return_value,
service_name,
region,
)
assert mock_stream_client.call_args[1]['auth'] is mock_auth
# Auth should be passed to the httpx client factory
assert mock_factory.call_args[1]['auth'] is mock_auth
# The created http client should be passed to streamable_http_client
assert mock_stream_client.call_args[1]['http_client'] is mock_http_client


@pytest.mark.asyncio
Expand All @@ -132,12 +137,14 @@ async def test_streamable_client_parameters(
mock_session, mock_streams, headers, timeout_value, sse_value, terminate_value
):
"""Test the correctness of streamablehttp_client parameters."""
# Verify that connection settings are forwarded as-is to the streamable HTTP client.
# timedelta values are allowed and compared directly here.
# Verify that connection settings are forwarded correctly to the httpx client factory
# and streamable HTTP client.
mock_read, mock_write, mock_get_session = mock_streams

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_http_client = Mock()
mock_factory = Mock(return_value=mock_http_client)
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -150,27 +157,34 @@ async def test_streamable_client_parameters(
timeout=timeout_value,
sse_read_timeout=sse_value,
terminate_on_close=terminate_value,
httpx_client_factory=mock_factory,
):
pass

call_kwargs = mock_stream_client.call_args[1]
# Confirm each parameter is forwarded unchanged.
assert call_kwargs['url'] == 'https://test.example.com/mcp'
assert call_kwargs['headers'] == headers
assert call_kwargs['timeout'] == timeout_value
assert call_kwargs['sse_read_timeout'] == sse_value
assert call_kwargs['terminate_on_close'] == terminate_value
# Verify headers and auth are passed to the factory
factory_call_kwargs = mock_factory.call_args[1]
assert factory_call_kwargs['headers'] == headers
# Timeout is passed to the factory (converted to httpx.Timeout)
assert factory_call_kwargs['timeout'] is not None

# Verify the created http client and other params are passed to streamable_http_client
stream_call_kwargs = mock_stream_client.call_args[1]
assert stream_call_kwargs['url'] == 'https://test.example.com/mcp'
assert stream_call_kwargs['http_client'] is mock_http_client
assert stream_call_kwargs['terminate_on_close'] == terminate_value


@pytest.mark.asyncio
async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams):
"""Test the passing of a custom HTTPX client factory."""
# The factory should be handed through to the underlying streamable client untouched.
# The factory should be used to create the http client.
mock_read, mock_write, mock_get_session = mock_streams
custom_factory = Mock()
mock_http_client = Mock()
custom_factory.return_value = mock_http_client

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -183,7 +197,10 @@ async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams)
):
pass

assert mock_stream_client.call_args[1]['httpx_client_factory'] is custom_factory
# Verify the custom factory was called
custom_factory.assert_called_once()
# Verify the http client from the factory was passed to streamable_http_client
assert mock_stream_client.call_args[1]['http_client'] is mock_http_client


@pytest.mark.asyncio
Expand All @@ -198,7 +215,7 @@ async def mock_aexit(*_):
cleanup_called = True

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -220,7 +237,7 @@ async def test_credentials_parameter_with_region(mock_streams):
creds = Credentials('test_key', 'test_secret', 'test_token')

with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_auth = Mock()
mock_auth_cls.return_value = mock_auth
mock_stream_client.return_value.__aenter__ = AsyncMock(
Expand Down Expand Up @@ -264,7 +281,7 @@ async def test_credentials_parameter_bypasses_boto3_session(mock_streams):

with patch('boto3.Session') as mock_boto:
with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth'):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand Down
Loading