|
7 | 7 | from mcp import types |
8 | 8 | from mcp.client.session_group import ( |
9 | 9 | ClientSessionGroup, |
| 10 | + ServerParameters, |
10 | 11 | SseServerParameters, |
11 | 12 | StreamableHttpParameters, |
12 | 13 | ) |
@@ -279,45 +280,42 @@ async def test_disconnect_non_existent_server(self): |
279 | 280 | await group.disconnect_from_server(session) |
280 | 281 |
|
281 | 282 | @pytest.mark.parametrize( |
282 | | - "server_params_instance, client_type_name, patch_target_for_client_func", |
| 283 | + "server_params_instance, patch_target_for_client_func", |
283 | 284 | [ |
284 | 285 | ( |
285 | 286 | StdioServerParameters(command="test_stdio_cmd"), |
286 | | - "stdio", |
287 | 287 | "mcp.client.session_group.mcp.stdio_client", |
288 | 288 | ), |
289 | 289 | ( |
290 | 290 | SseServerParameters(url="http://test.com/sse", timeout=10), |
291 | | - "sse", |
292 | 291 | "mcp.client.session_group.sse_client", |
293 | 292 | ), # url, headers, timeout, sse_read_timeout |
294 | 293 | ( |
295 | 294 | StreamableHttpParameters( |
296 | 295 | url="http://test.com/stream", terminate_on_close=False |
297 | 296 | ), |
298 | | - "streamablehttp", |
299 | 297 | "mcp.client.session_group.streamable_http_client", |
300 | 298 | ), # url, headers, timeout, sse_read_timeout, terminate_on_close |
301 | 299 | ], |
302 | 300 | ) |
303 | 301 | async def test_establish_session_parameterized( |
304 | 302 | self, |
305 | | - server_params_instance, |
306 | | - client_type_name, # Just for clarity or conditional logic if needed |
307 | | - patch_target_for_client_func, |
| 303 | + server_params_instance: ServerParameters, |
| 304 | + patch_target_for_client_func: str, |
308 | 305 | ): |
309 | 306 | with mock.patch( |
310 | 307 | "mcp.client.session_group.mcp.ClientSession" |
311 | 308 | ) as mock_ClientSession_class: |
312 | 309 | with mock.patch(patch_target_for_client_func) as mock_specific_client_func: |
| 310 | + client_type_name = server_params_instance.__class__.__name__ |
313 | 311 | mock_client_cm_instance = mock.AsyncMock( |
314 | 312 | name=f"{client_type_name}ClientCM" |
315 | 313 | ) |
316 | 314 | mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") |
317 | 315 | mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") |
318 | 316 |
|
319 | 317 | # streamable_http_client's __aenter__ returns three values |
320 | | - if client_type_name == "streamablehttp": |
| 318 | + if isinstance(server_params_instance, StreamableHttpParameters): |
321 | 319 | mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra") |
322 | 320 | mock_client_cm_instance.__aenter__.return_value = ( |
323 | 321 | mock_read_stream, |
@@ -363,23 +361,23 @@ async def test_establish_session_parameterized( |
363 | 361 |
|
364 | 362 | # --- Assertions --- |
365 | 363 | # 1. Assert the correct specific client function was called |
366 | | - if client_type_name == "stdio": |
| 364 | + if isinstance(server_params_instance, StdioServerParameters): |
367 | 365 | mock_specific_client_func.assert_called_once_with( |
368 | 366 | server_params_instance |
369 | 367 | ) |
370 | | - elif client_type_name == "sse": |
| 368 | + elif isinstance(server_params_instance, SseServerParameters): |
371 | 369 | mock_specific_client_func.assert_called_once_with( |
372 | 370 | url=server_params_instance.url, |
373 | 371 | headers=server_params_instance.headers, |
374 | 372 | timeout=server_params_instance.timeout, |
375 | 373 | sse_read_timeout=server_params_instance.sse_read_timeout, |
376 | 374 | ) |
377 | | - elif client_type_name == "streamablehttp": |
| 375 | + else: |
378 | 376 | mock_specific_client_func.assert_called_once_with( |
379 | 377 | url=server_params_instance.url, |
380 | 378 | headers=server_params_instance.headers, |
381 | | - timeout=server_params_instance.timeout, |
382 | | - sse_read_timeout=server_params_instance.sse_read_timeout, |
| 379 | + timeout=server_params_instance.timeout.total_seconds(), |
| 380 | + sse_read_timeout=server_params_instance.sse_read_timeout.total_seconds(), |
383 | 381 | terminate_on_close=server_params_instance.terminate_on_close, |
384 | 382 | ) |
385 | 383 |
|
|
0 commit comments