Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions mcp_proxy_for_aws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

from importlib.metadata import version as _metadata_version

import mcp_proxy_for_aws.fastmcp_patch as _fastmcp_patch


__all__ = ['__version__']
__version__ = _metadata_version('mcp-proxy-for-aws')
48 changes: 0 additions & 48 deletions mcp_proxy_for_aws/fastmcp_patch.py

This file was deleted.

4 changes: 2 additions & 2 deletions mcp_proxy_for_aws/middleware/initialize_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(self, client_factory: AWSMCPProxyClientFactory) -> None:
async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next: CallNext[mt.InitializeRequest, None],
) -> None:
call_next: CallNext[mt.InitializeRequest, mt.InitializeResult | None],
) -> mt.InitializeResult | None:
try:
logger.debug('Received initialize request %s.', context.message)
self._client_factory.set_init_params(context.message)
Expand Down
2 changes: 1 addition & 1 deletion mcp_proxy_for_aws/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class AWSMCPProxy(_FastMCPProxy):
def __init__(
self,
*,
client_factory: ClientFactoryT | None = None,
client_factory: ClientFactoryT,
**kwargs,
):
"""Initialize a client."""
Expand Down
2 changes: 2 additions & 0 deletions mcp_proxy_for_aws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def client_factory(
headers: Optional[Dict[str, str]] = None,
timeout: Optional[httpx.Timeout] = None,
auth: Optional[httpx.Auth] = None,
**kw,
) -> httpx.AsyncClient:
return create_sigv4_client(
service=service,
Expand All @@ -66,6 +67,7 @@ def client_factory(
timeout=custom_timeout,
metadata=metadata,
auth=auth,
**kw,
)

return StreamableHttpTransport(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ description = "MCP Proxy for AWS"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"fastmcp (>=2.13.1,<2.14.1)",
"fastmcp~=2.14.1",
"boto3>=1.41.0",
"botocore[crt]>=1.41.0",
]
Expand Down
101 changes: 101 additions & 0 deletions tests/unit/test_fastmcp_initialize_error_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mcp.types as mt
import pytest
from fastmcp import FastMCP
from fastmcp.client import Client
from fastmcp.server.middleware import Middleware, MiddlewareContext
from mcp import McpError
from mcp.types import ErrorData


@pytest.mark.asyncio
async def test_fastmcp_handles_initialize_error_from_middleware():
"""Test that fastmcp properly handles McpError raised during initialization middleware.
This validates that the fix from https://github.com/jlowin/fastmcp/pull/2531 works,
ensuring that initialization errors are sent back to the client instead of crashing.
"""

class InitializeErrorMiddleware(Middleware):
"""Middleware that raises an error during initialization."""

async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next,
):
raise McpError(ErrorData(code=-1, message='Initialization failed from middleware'))

server = FastMCP('test-server')
server.add_middleware(InitializeErrorMiddleware())

@server.tool()
def test_tool() -> str:
"""A test tool."""
return 'success'

client = Client(server)

# The client should receive the error during initialization
with pytest.raises(McpError) as exc_info:
async with client:
pass

# Verify the error contains our custom error data
assert exc_info.value.error.code == -1
assert exc_info.value.error.message == 'Initialization failed from middleware'


@pytest.mark.asyncio
async def test_fastmcp_handles_error_after_initialization_completes():
"""Test that fastmcp handles McpError raised AFTER initialization completes.
This validates that when an error is raised after call_next (when responder is already
completed), fastmcp logs a warning but doesn't crash. The client receives the successful
initialization response, not the error.
This is a current limitation of fastmcp - errors raised after call_next cannot be sent
to the client because the response has already been sent.
"""
server = FastMCP('test-server')

class PostInitializeErrorMiddleware(Middleware):
"""Middleware that raises an error AFTER initialization completes."""

async def on_initialize(
self,
context: MiddlewareContext[mt.InitializeRequest],
call_next,
):
await call_next(context)
# Raising error after call_next - responder is already completed
raise McpError(ErrorData(code=-1, message='Error after initialization'))

server.add_middleware(PostInitializeErrorMiddleware())

@server.tool()
def test_tool() -> str:
"""A test tool."""
return 'success'

client = Client(server)

# Client should still initialize successfully because the error happens after response is sent
async with client:
# Verify we can list tools - initialization succeeded despite the error
tools = await client.list_tools()
assert len(tools) > 0
assert tools[0].name == 'test_tool'
120 changes: 0 additions & 120 deletions tests/unit/test_fastmcp_patch.py

This file was deleted.

35 changes: 35 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,41 @@ def test_create_transport_with_sigv4_no_profile(
# If we can't access the factory directly, just verify the transport was created
assert result is not None

@patch('mcp_proxy_for_aws.utils.create_aws_session')
@patch('mcp_proxy_for_aws.utils.create_sigv4_client')
def test_create_transport_with_sigv4_kwargs_passthrough(
self, mock_create_sigv4_client, mock_create_session
):
"""Test that kwargs are passed through to create_sigv4_client."""
from httpx import Timeout

mock_session = MagicMock()
mock_create_session.return_value = mock_session

url = 'https://test-service.us-west-2.api.aws/mcp'
service = 'test-service'
region = 'test-region'
metadata = {'AWS_REGION': 'test-region'}
custom_timeout = Timeout(60.0)

result = create_transport_with_sigv4(url, service, region, metadata, custom_timeout)

assert hasattr(result, 'httpx_client_factory')
factory = result.httpx_client_factory
assert factory is not None
factory(headers=None, timeout=None, auth=None, follow_redirects=True) # type: ignore[call-arg]

mock_create_sigv4_client.assert_called_once_with(
service=service,
session=mock_session,
region=region,
headers=None,
timeout=custom_timeout,
auth=None,
metadata=metadata,
follow_redirects=True,
)


class TestValidateRequiredArgs:
"""Test cases for validate_service_name function."""
Expand Down
Loading
Loading