@@ -250,3 +250,112 @@ async def mock_server():
250250
251251 # Assert that the default client info was sent
252252 assert received_client_info == DEFAULT_CLIENT_INFO
253+
254+
255+ @pytest .mark .anyio
256+ async def test_client_session_progress ():
257+ client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [
258+ SessionMessage
259+ ](1 )
260+ server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [
261+ SessionMessage
262+ ](1 )
263+
264+ async def mock_server ():
265+ session_message = await client_to_server_receive .receive ()
266+ jsonrpc_request = session_message .message
267+ assert isinstance (jsonrpc_request .root , JSONRPCRequest )
268+ request = ClientRequest .model_validate (
269+ jsonrpc_request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
270+ )
271+ assert isinstance (request .root , types .CallToolRequest )
272+ assert request .root .params .meta
273+ assert request .root .params .meta .progressToken is not None
274+
275+ progress_token = request .root .params .meta .progressToken
276+
277+ notifications = [
278+ types .ServerNotification (
279+ root = types .ProgressNotification (
280+ params = types .ProgressNotificationParams (
281+ progressToken = progress_token , progress = 1
282+ ),
283+ method = "notifications/progress" ,
284+ )
285+ ),
286+ types .ServerNotification (
287+ root = types .ProgressNotification (
288+ params = types .ProgressNotificationParams (
289+ progressToken = progress_token , progress = 2
290+ ),
291+ method = "notifications/progress" ,
292+ )
293+ ),
294+ ]
295+ result = ServerResult (types .CallToolResult (content = []))
296+
297+ async with server_to_client_send :
298+ for notification in notifications :
299+ await server_to_client_send .send (
300+ SessionMessage (
301+ JSONRPCMessage (
302+ types .JSONRPCNotification (
303+ jsonrpc = "2.0" ,
304+ ** notification .model_dump (
305+ by_alias = True , mode = "json" , exclude_none = True
306+ ),
307+ )
308+ )
309+ )
310+ )
311+ await server_to_client_send .send (
312+ SessionMessage (
313+ JSONRPCMessage (
314+ JSONRPCResponse (
315+ jsonrpc = "2.0" ,
316+ id = jsonrpc_request .root .id ,
317+ result = result .model_dump (
318+ by_alias = True , mode = "json" , exclude_none = True
319+ ),
320+ )
321+ )
322+ )
323+ )
324+
325+ # Create a message handler to catch exceptions
326+ async def message_handler (
327+ message : RequestResponder [types .ServerRequest , types .ClientResult ]
328+ | types .ServerNotification
329+ | Exception ,
330+ ) -> None :
331+ if isinstance (message , Exception ):
332+ raise message
333+
334+ progress_count = 0
335+
336+ async with (
337+ ClientSession (
338+ server_to_client_receive ,
339+ client_to_server_send ,
340+ message_handler = message_handler ,
341+ ) as session ,
342+ anyio .create_task_group () as tg ,
343+ client_to_server_send ,
344+ client_to_server_receive ,
345+ server_to_client_send ,
346+ server_to_client_receive ,
347+ ):
348+ tg .start_soon (mock_server )
349+
350+ async def progress_callback (params : types .ProgressNotificationParams ):
351+ nonlocal progress_count
352+ progress_count = progress_count + 1
353+
354+ result = await session .call_tool (
355+ "tool_with_progress" , progress_callback = progress_callback
356+ )
357+
358+ # Assert the result
359+ assert isinstance (result , types .CallToolResult )
360+ assert len (result .content ) == 0
361+ assert progress_count == 2
0 commit comments