Skip to content

Commit bf89546

Browse files
committed
simplify test
1 parent f561854 commit bf89546

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

tests/client/test_session_group.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mcp import types
88
from mcp.client.session_group import (
99
ClientSessionGroup,
10+
ServerParameters,
1011
SseServerParameters,
1112
StreamableHttpParameters,
1213
)
@@ -279,45 +280,42 @@ async def test_disconnect_non_existent_server(self):
279280
await group.disconnect_from_server(session)
280281

281282
@pytest.mark.parametrize(
282-
"server_params_instance, client_type_name, patch_target_for_client_func",
283+
"server_params_instance, patch_target_for_client_func",
283284
[
284285
(
285286
StdioServerParameters(command="test_stdio_cmd"),
286-
"stdio",
287287
"mcp.client.session_group.mcp.stdio_client",
288288
),
289289
(
290290
SseServerParameters(url="http://test.com/sse", timeout=10),
291-
"sse",
292291
"mcp.client.session_group.sse_client",
293292
), # url, headers, timeout, sse_read_timeout
294293
(
295294
StreamableHttpParameters(
296295
url="http://test.com/stream", terminate_on_close=False
297296
),
298-
"streamablehttp",
299297
"mcp.client.session_group.streamable_http_client",
300298
), # url, headers, timeout, sse_read_timeout, terminate_on_close
301299
],
302300
)
303301
async def test_establish_session_parameterized(
304302
self,
305-
server_params_instance,
306-
client_type_name, # Just for clarity or conditional logic if needed
307-
patch_target_for_client_func,
303+
server_params_instance: ServerParameters,
304+
patch_target_for_client_func: str,
308305
):
309306
with mock.patch(
310307
"mcp.client.session_group.mcp.ClientSession"
311308
) as mock_ClientSession_class:
312309
with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
310+
client_type_name = server_params_instance.__class__.__name__
313311
mock_client_cm_instance = mock.AsyncMock(
314312
name=f"{client_type_name}ClientCM"
315313
)
316314
mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
317315
mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")
318316

319317
# streamable_http_client's __aenter__ returns three values
320-
if client_type_name == "streamablehttp":
318+
if isinstance(server_params_instance, StreamableHttpParameters):
321319
mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra")
322320
mock_client_cm_instance.__aenter__.return_value = (
323321
mock_read_stream,
@@ -363,23 +361,23 @@ async def test_establish_session_parameterized(
363361

364362
# --- Assertions ---
365363
# 1. Assert the correct specific client function was called
366-
if client_type_name == "stdio":
364+
if isinstance(server_params_instance, StdioServerParameters):
367365
mock_specific_client_func.assert_called_once_with(
368366
server_params_instance
369367
)
370-
elif client_type_name == "sse":
368+
elif isinstance(server_params_instance, SseServerParameters):
371369
mock_specific_client_func.assert_called_once_with(
372370
url=server_params_instance.url,
373371
headers=server_params_instance.headers,
374372
timeout=server_params_instance.timeout,
375373
sse_read_timeout=server_params_instance.sse_read_timeout,
376374
)
377-
elif client_type_name == "streamablehttp":
375+
else:
378376
mock_specific_client_func.assert_called_once_with(
379377
url=server_params_instance.url,
380378
headers=server_params_instance.headers,
381-
timeout=server_params_instance.timeout,
382-
sse_read_timeout=server_params_instance.sse_read_timeout,
379+
timeout=server_params_instance.timeout.total_seconds(),
380+
sse_read_timeout=server_params_instance.sse_read_timeout.total_seconds(),
383381
terminate_on_close=server_params_instance.terminate_on_close,
384382
)
385383

0 commit comments

Comments
 (0)