@@ -60,7 +60,7 @@ async def test_boto3_session_parameters(
6060 mock_read , mock_write , mock_get_session = mock_streams
6161
6262 with patch ('boto3.Session' , return_value = mock_session ) as mock_boto :
63- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
63+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
6464 mock_stream_client .return_value .__aenter__ = AsyncMock (
6565 return_value = (mock_read , mock_write , mock_get_session )
6666 )
@@ -94,9 +94,14 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic
9494
9595 with patch ('boto3.Session' , return_value = mock_session ):
9696 with patch ('mcp_proxy_for_aws.client.SigV4HTTPXAuth' ) as mock_auth_cls :
97- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
97+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
9898 mock_auth = Mock ()
9999 mock_auth_cls .return_value = mock_auth
100+
101+ # Mock the factory to capture its calls
102+ mock_http_client = Mock ()
103+ mock_factory = Mock (return_value = mock_http_client )
104+
100105 mock_stream_client .return_value .__aenter__ = AsyncMock (
101106 return_value = (mock_read , mock_write , mock_get_session )
102107 )
@@ -106,17 +111,22 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic
106111 endpoint = 'https://test.example.com/mcp' ,
107112 aws_service = service_name ,
108113 aws_region = region ,
114+ httpx_client_factory = mock_factory ,
109115 ):
110116 pass
111117
112118 mock_auth_cls .assert_called_once_with (
113119 # Auth should be constructed with the resolved credentials, service, and region,
114- # and passed into the streamable client.
120+ # and passed to the httpx client factory .
115121 mock_session .get_credentials .return_value ,
116122 service_name ,
117123 region ,
118124 )
119- assert mock_stream_client .call_args [1 ]['auth' ] is mock_auth
125+ # Check that factory was called with auth
126+ assert mock_factory .called
127+ assert mock_factory .call_args [1 ]['auth' ] is mock_auth
128+ # Check that http_client was passed to streamable_http_client
129+ assert mock_stream_client .call_args [1 ]['http_client' ] is mock_http_client
120130
121131
122132@pytest .mark .asyncio
@@ -137,7 +147,10 @@ async def test_streamable_client_parameters(
137147 mock_read , mock_write , mock_get_session = mock_streams
138148
139149 with patch ('boto3.Session' , return_value = mock_session ):
140- with patch ('mcp_proxy_for_aws.client.streamablehttp_client' ) as mock_stream_client :
150+ with patch ('mcp_proxy_for_aws.client.streamable_http_client' ) as mock_stream_client :
151+ mock_http_client = Mock ()
152+ mock_factory = Mock (return_value = mock_http_client )
153+
141154 mock_stream_client .return_value .__aenter__ = AsyncMock (
142155 return_value = (mock_read , mock_write , mock_get_session )
143156 )
@@ -150,16 +163,30 @@ async def test_streamable_client_parameters(
150163 timeout = timeout_value ,
151164 sse_read_timeout = sse_value ,
152165 terminate_on_close = terminate_value ,
166+ httpx_client_factory = mock_factory ,
153167 ):
154168 pass
155169
156- call_kwargs = mock_stream_client .call_args [1 ]
157- # Confirm each parameter is forwarded unchanged.
158- assert call_kwargs ['url' ] == 'https://test.example.com/mcp'
159- assert call_kwargs ['headers' ] == headers
160- assert call_kwargs ['timeout' ] == timeout_value
161- assert call_kwargs ['sse_read_timeout' ] == sse_value
162- assert call_kwargs ['terminate_on_close' ] == terminate_value
170+ # Check that factory was called with headers and timeout
171+ assert mock_factory .called
172+ factory_kwargs = mock_factory .call_args [1 ]
173+ assert factory_kwargs ['headers' ] == headers
174+ # Check timeout conversion
175+ if isinstance (timeout_value , timedelta ):
176+ expected_timeout = timeout_value .total_seconds ()
177+ else :
178+ expected_timeout = timeout_value
179+ # httpx.Timeout sets all timeout types (connect, read, write, pool) to the same value
180+ assert factory_kwargs ['timeout' ].connect == expected_timeout
181+ assert factory_kwargs ['timeout' ].read == expected_timeout
182+ assert factory_kwargs ['timeout' ].write == expected_timeout
183+ assert factory_kwargs ['timeout' ].pool == expected_timeout
184+
185+ # Check streamable_http_client was called correctly
186+ stream_kwargs = mock_stream_client .call_args [1 ]
187+ assert stream_kwargs ['url' ] == 'https://test.example.com/mcp'
188+ assert stream_kwargs ['http_client' ] is mock_http_client
189+ assert stream_kwargs ['terminate_on_close' ] == terminate_value
163190
164191
165192@pytest .mark .asyncio
@@ -170,7 +197,9 @@ async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams)
170197 custom_factory = Mock ()
171198
172199 with patch ('boto3.Session' , return_value = mock_session ):
173- with patch ('mcp_proxy_for_aws.client.streamablehttp_client' ) as mock_stream_client :
200+ with patch ('mcp_proxy_for_aws.client.streamable_http_client' ) as mock_stream_client :
201+ mock_http_client = Mock ()
202+ custom_factory .return_value = mock_http_client
174203 mock_stream_client .return_value .__aenter__ = AsyncMock (
175204 return_value = (mock_read , mock_write , mock_get_session )
176205 )
@@ -183,7 +212,10 @@ async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams)
183212 ):
184213 pass
185214
186- assert mock_stream_client .call_args [1 ]['httpx_client_factory' ] is custom_factory
215+ # Check that the custom factory was called
216+ assert custom_factory .called
217+ # Check that the http_client from custom factory was passed to streamable_http_client
218+ assert mock_stream_client .call_args [1 ]['http_client' ] is mock_http_client
187219
188220
189221@pytest .mark .asyncio
@@ -198,7 +230,7 @@ async def mock_aexit(*_):
198230 cleanup_called = True
199231
200232 with patch ('boto3.Session' , return_value = mock_session ):
201- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
233+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
202234 mock_stream_client .return_value .__aenter__ = AsyncMock (
203235 return_value = (mock_read , mock_write , mock_get_session )
204236 )
@@ -220,7 +252,7 @@ async def test_credentials_parameter_with_region(mock_streams):
220252 creds = Credentials ('test_key' , 'test_secret' , 'test_token' )
221253
222254 with patch ('mcp_proxy_for_aws.client.SigV4HTTPXAuth' ) as mock_auth_cls :
223- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
255+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
224256 mock_auth = Mock ()
225257 mock_auth_cls .return_value = mock_auth
226258 mock_stream_client .return_value .__aenter__ = AsyncMock (
@@ -264,7 +296,7 @@ async def test_credentials_parameter_bypasses_boto3_session(mock_streams):
264296
265297 with patch ('boto3.Session' ) as mock_boto :
266298 with patch ('mcp_proxy_for_aws.client.SigV4HTTPXAuth' ):
267- with patch ('mcp_proxy_for_aws.client.streamablehttp_client ' ) as mock_stream_client :
299+ with patch ('mcp_proxy_for_aws.client.streamable_http_client ' ) as mock_stream_client :
268300 mock_stream_client .return_value .__aenter__ = AsyncMock (
269301 return_value = (mock_read , mock_write , mock_get_session )
270302 )
0 commit comments