1+ from contextlib import AsyncExitStack
2+ from unittest .mock import AsyncMock , Mock
3+
14import pytest
5+
26from mcp import types
37from mcp .server .lowlevel .result_cache import ResultCache
4- from unittest .mock import AsyncMock , Mock , patch
5- from contextlib import AsyncExitStack
8+
69
710@pytest .mark .anyio
811async def test_async_call ():
912 """Tests basic async call"""
13+
1014 async def test_call (call : types .CallToolRequest ) -> types .ServerResult :
11- return types .ServerResult (types .CallToolResult (
12- content = [types .TextContent (
13- type = "text" ,
14- text = "test"
15- )]
16- ))
17- async_call = types .CallToolAsyncRequest (
18- method = "tools/async/call" ,
19- params = types .CallToolAsyncRequestParams (
20- name = "test"
15+ return types .ServerResult (
16+ types .CallToolResult (content = [types .TextContent (type = "text" , text = "test" )])
2117 )
18+
19+ async_call = types .CallToolAsyncRequest (
20+ method = "tools/async/call" , params = types .CallToolAsyncRequestParams (name = "test" )
2221 )
2322
2423 mock_session = AsyncMock ()
@@ -27,37 +26,38 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
2726 result_cache = ResultCache (max_size = 1 , max_keep_alive = 1 )
2827 async with AsyncExitStack () as stack :
2928 await stack .enter_async_context (result_cache )
30- async_call_ref = await result_cache .start_call (test_call , async_call , mock_context )
29+ async_call_ref = await result_cache .start_call (
30+ test_call , async_call , mock_context
31+ )
3132 assert async_call_ref .token is not None
3233
33- result = await result_cache .get_result (types .GetToolAsyncResultRequest (
34- method = "tools/async/get" ,
35- params = types .GetToolAsyncResultRequestParams (
36- token = async_call_ref .token
34+ result = await result_cache .get_result (
35+ types .GetToolAsyncResultRequest (
36+ method = "tools/async/get" ,
37+ params = types .GetToolAsyncResultRequestParams (
38+ token = async_call_ref .token
39+ ),
3740 )
38- ))
41+ )
3942
4043 assert not result .isError
4144 assert not result .isPending
4245 assert len (result .content ) == 1
4346 assert type (result .content [0 ]) is types .TextContent
4447 assert result .content [0 ].text == "test"
4548
49+
4650@pytest .mark .anyio
4751async def test_async_join_call_progress ():
4852 """Tests basic async call"""
53+
4954 async def test_call (call : types .CallToolRequest ) -> types .ServerResult :
50- return types .ServerResult (types .CallToolResult (
51- content = [types .TextContent (
52- type = "text" ,
53- text = "test"
54- )]
55- ))
56- async_call = types .CallToolAsyncRequest (
57- method = "tools/async/call" ,
58- params = types .CallToolAsyncRequestParams (
59- name = "test"
55+ return types .ServerResult (
56+ types .CallToolResult (content = [types .TextContent (type = "text" , text = "test" )])
6057 )
58+
59+ async_call = types .CallToolAsyncRequest (
60+ method = "tools/async/call" , params = types .CallToolAsyncRequestParams (name = "test" )
6161 )
6262
6363 mock_session_1 = AsyncMock ()
@@ -73,38 +73,42 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
7373 result_cache = ResultCache (max_size = 1 , max_keep_alive = 1 )
7474 async with AsyncExitStack () as stack :
7575 await stack .enter_async_context (result_cache )
76- async_call_ref = await result_cache .start_call (test_call , async_call , mock_context_1 )
76+ async_call_ref = await result_cache .start_call (
77+ test_call , async_call , mock_context_1
78+ )
7779 assert async_call_ref .token is not None
7880
7981 await result_cache .join_call (
8082 req = types .JoinCallToolAsyncRequest (
8183 method = "tools/async/join" ,
8284 params = types .JoinCallToolRequestParams (
8385 token = async_call_ref .token ,
84- _meta = types .RequestParams .Meta (
85- progressToken = "test"
86- )
87- )
86+ _meta = types .RequestParams .Meta (progressToken = "test" ),
87+ ),
8888 ),
89- ctx = mock_context_2
89+ ctx = mock_context_2 ,
9090 )
9191 assert async_call_ref .token is not None
9292 await result_cache .notification_hook (
93- session = mock_session_1 ,
94- notification = types .ServerNotification (types .ProgressNotification (
95- method = "notifications/progress" ,
96- params = types .ProgressNotificationParams (
97- progressToken = "test" ,
98- progress = 1
93+ session = mock_session_1 ,
94+ notification = types .ServerNotification (
95+ types .ProgressNotification (
96+ method = "notifications/progress" ,
97+ params = types .ProgressNotificationParams (
98+ progressToken = "test" , progress = 1
99+ ),
99100 )
100- )))
101+ ),
102+ )
101103
102- result = await result_cache .get_result (types .GetToolAsyncResultRequest (
103- method = "tools/async/get" ,
104- params = types .GetToolAsyncResultRequestParams (
105- token = async_call_ref .token
104+ result = await result_cache .get_result (
105+ types .GetToolAsyncResultRequest (
106+ method = "tools/async/get" ,
107+ params = types .GetToolAsyncResultRequestParams (
108+ token = async_call_ref .token
109+ ),
106110 )
107- ))
111+ )
108112
109113 assert not result .isError
110114 assert not result .isPending
@@ -113,9 +117,9 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
113117 assert result .content [0 ].text == "test"
114118 mock_context_1 .send_progress_notification .assert_not_called ()
115119 mock_session_2 .send_progress_notification .assert_called_with (
116- progress_token = "test" ,
117- progress = 1.0 ,
118- total = None ,
119- message = None ,
120- resource_uri = None
120+ progress_token = "test" ,
121+ progress = 1.0 ,
122+ total = None ,
123+ message = None ,
124+ resource_uri = None ,
121125 )
0 commit comments