@@ -60,7 +60,7 @@ async def test_boto3_session_parameters(
6060 mock_read , mock_write , mock_get_session = mock_streams
6161
6262 with patch ('boto3.Session' , return_value = mock_session ) as mock_boto :
63- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
63+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
6464 mock_stream_client .return_value .__aenter__ = AsyncMock (
6565 return_value = (mock_read , mock_write , mock_get_session )
6666 )
@@ -94,9 +94,11 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic
9494
9595 with patch ('boto3.Session' , return_value = mock_session ):
9696 with patch ('mcp_proxy_for_aws.client.SigV4HTTPXAuth' ) as mock_auth_cls :
97- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
97+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
9898 mock_auth = Mock ()
9999 mock_auth_cls .return_value = mock_auth
100+ mock_http_client = Mock ()
101+ mock_factory = Mock (return_value = mock_http_client )
100102 mock_stream_client .return_value .__aenter__ = AsyncMock (
101103 return_value = (mock_read , mock_write , mock_get_session )
102104 )
@@ -106,17 +108,20 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic
106108 endpoint = 'https://test.example.com/mcp' ,
107109 aws_service = service_name ,
108110 aws_region = region ,
111+ httpx_client_factory = mock_factory ,
109112 ):
110113 pass
111114
112115 mock_auth_cls .assert_called_once_with (
113- # Auth should be constructed with the resolved credentials, service, and region,
114- # and passed into the streamable client.
116+ # Auth should be constructed with the resolved credentials, service, and region
115117 mock_session .get_credentials .return_value ,
116118 service_name ,
117119 region ,
118120 )
119- assert mock_stream_client .call_args [1 ]['auth' ] is mock_auth
121+ # Auth should be passed to the httpx client factory
122+ assert mock_factory .call_args [1 ]['auth' ] is mock_auth
123+ # The created http client should be passed to streamable_http_client
124+ assert mock_stream_client .call_args [1 ]['http_client' ] is mock_http_client
120125
121126
122127@pytest .mark .asyncio
@@ -132,12 +137,14 @@ async def test_streamable_client_parameters(
132137 mock_session , mock_streams , headers , timeout_value , sse_value , terminate_value
133138):
134139 """Test the correctness of streamablehttp_client parameters."""
135- # Verify that connection settings are forwarded as-is to the streamable HTTP client.
136- # timedelta values are allowed and compared directly here .
140+ # Verify that connection settings are forwarded correctly to the httpx client factory
141+ # and streamable HTTP client .
137142 mock_read , mock_write , mock_get_session = mock_streams
138143
139144 with patch ('boto3.Session' , return_value = mock_session ):
140- with patch ('mcp_proxy_for_aws.client.streamablehttp_client' ) as mock_stream_client :
145+ with patch ('mcp_proxy_for_aws.client.streamable_http_client' ) as mock_stream_client :
146+ mock_http_client = Mock ()
147+ mock_factory = Mock (return_value = mock_http_client )
141148 mock_stream_client .return_value .__aenter__ = AsyncMock (
142149 return_value = (mock_read , mock_write , mock_get_session )
143150 )
@@ -150,27 +157,34 @@ async def test_streamable_client_parameters(
150157 timeout = timeout_value ,
151158 sse_read_timeout = sse_value ,
152159 terminate_on_close = terminate_value ,
160+ httpx_client_factory = mock_factory ,
153161 ):
154162 pass
155163
156- call_kwargs = mock_stream_client .call_args [1 ]
157- # Confirm each parameter is forwarded unchanged.
158- assert call_kwargs ['url' ] == 'https://test.example.com/mcp'
159- assert call_kwargs ['headers' ] == headers
160- assert call_kwargs ['timeout' ] == timeout_value
161- assert call_kwargs ['sse_read_timeout' ] == sse_value
162- assert call_kwargs ['terminate_on_close' ] == terminate_value
164+ # Verify headers and auth are passed to the factory
165+ factory_call_kwargs = mock_factory .call_args [1 ]
166+ assert factory_call_kwargs ['headers' ] == headers
167+ # Timeout is passed to the factory (converted to httpx.Timeout)
168+ assert factory_call_kwargs ['timeout' ] is not None
169+
170+ # Verify the created http client and other params are passed to streamable_http_client
171+ stream_call_kwargs = mock_stream_client .call_args [1 ]
172+ assert stream_call_kwargs ['url' ] == 'https://test.example.com/mcp'
173+ assert stream_call_kwargs ['http_client' ] is mock_http_client
174+ assert stream_call_kwargs ['terminate_on_close' ] == terminate_value
163175
164176
165177@pytest .mark .asyncio
166178async def test_custom_httpx_client_factory_is_passed (mock_session , mock_streams ):
167179 """Test the passing of a custom HTTPX client factory."""
168- # The factory should be handed through to the underlying streamable client untouched .
180+ # The factory should be used to create the http client.
169181 mock_read , mock_write , mock_get_session = mock_streams
170182 custom_factory = Mock ()
183+ mock_http_client = Mock ()
184+ custom_factory .return_value = mock_http_client
171185
172186 with patch ('boto3.Session' , return_value = mock_session ):
173- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
187+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
174188 mock_stream_client .return_value .__aenter__ = AsyncMock (
175189 return_value = (mock_read , mock_write , mock_get_session )
176190 )
@@ -183,7 +197,10 @@ async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams)
183197 ):
184198 pass
185199
186- assert mock_stream_client .call_args [1 ]['httpx_client_factory' ] is custom_factory
200+ # Verify the custom factory was called
201+ custom_factory .assert_called_once ()
202+ # Verify the http client from the factory was passed to streamable_http_client
203+ assert mock_stream_client .call_args [1 ]['http_client' ] is mock_http_client
187204
188205
189206@pytest .mark .asyncio
@@ -198,7 +215,7 @@ async def mock_aexit(*_):
198215 cleanup_called = True
199216
200217 with patch ('boto3.Session' , return_value = mock_session ):
201- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
218+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
202219 mock_stream_client .return_value .__aenter__ = AsyncMock (
203220 return_value = (mock_read , mock_write , mock_get_session )
204221 )
@@ -220,7 +237,7 @@ async def test_credentials_parameter_with_region(mock_streams):
220237 creds = Credentials ('test_key' , 'test_secret' , 'test_token' )
221238
222239 with patch ('mcp_proxy_for_aws.client.SigV4HTTPXAuth' ) as mock_auth_cls :
223- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
240+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
224241 mock_auth = Mock ()
225242 mock_auth_cls .return_value = mock_auth
226243 mock_stream_client .return_value .__aenter__ = AsyncMock (
@@ -264,7 +281,7 @@ async def test_credentials_parameter_bypasses_boto3_session(mock_streams):
264281
265282 with patch ('boto3.Session' ) as mock_boto :
266283 with patch ('mcp_proxy_for_aws.client.SigV4HTTPXAuth' ):
267- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
284+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
268285 mock_stream_client .return_value .__aenter__ = AsyncMock (
269286 return_value = (mock_read , mock_write , mock_get_session )
270287 )
0 commit comments