88import pytest
99import httpx
1010
11- from a2a .server .agent_execution import AgentExecutor
11+ from a2a .server .agent_execution import AgentExecutor , RequestContext
1212from a2a .server .events import (
1313 QueueManager ,
1414)
5555 TaskStatusUpdateEvent ,
5656 TextPart ,
5757 UnsupportedOperationError ,
58+ InternalError ,
5859)
5960from a2a .utils .errors import ServerError
6061
@@ -183,7 +184,12 @@ async def test_on_cancel_task_not_found(self) -> None:
183184 mock_task_store .get .assert_called_once_with ('nonexistent_id' )
184185 mock_agent_executor .cancel .assert_not_called ()
185186
186- async def test_on_message_new_message_success (self ) -> None :
187+ @patch (
188+ 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
189+ )
190+ async def test_on_message_new_message_success (
191+ self , _mock_builder_build : AsyncMock
192+ ) -> None :
187193 mock_agent_executor = AsyncMock (spec = AgentExecutor )
188194 mock_task_store = AsyncMock (spec = TaskStore )
189195 request_handler = DefaultRequestHandler (
@@ -194,6 +200,14 @@ async def test_on_message_new_message_success(self) -> None:
194200 mock_task_store .get .return_value = mock_task
195201 mock_agent_executor .execute .return_value = None
196202
203+ _mock_builder_build .return_value = RequestContext (
204+ request = MagicMock (),
205+ task_id = 'task_123' ,
206+ context_id = 'session-xyz' ,
207+ task = None ,
208+ related_tasks = None ,
209+ )
210+
197211 async def streaming_coro ():
198212 yield mock_task
199213
@@ -279,15 +293,28 @@ async def streaming_coro():
279293 assert response .root .error == UnsupportedOperationError () # type: ignore
280294 mock_agent_executor .execute .assert_called_once ()
281295
282- async def test_on_message_stream_new_message_success (self ) -> None :
296+ @patch (
297+ 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
298+ )
299+ async def test_on_message_stream_new_message_success (
300+ self , _mock_builder_build : AsyncMock
301+ ) -> None :
283302 mock_agent_executor = AsyncMock (spec = AgentExecutor )
284303 mock_task_store = AsyncMock (spec = TaskStore )
285304 request_handler = DefaultRequestHandler (
286305 mock_agent_executor , mock_task_store
287306 )
288- self .mock_agent_card .capabilities = AgentCapabilities (streaming = True )
289307
308+ self .mock_agent_card .capabilities = AgentCapabilities (streaming = True )
290309 handler = JSONRPCHandler (self .mock_agent_card , request_handler )
310+ _mock_builder_build .return_value = RequestContext (
311+ request = MagicMock (),
312+ task_id = 'task_123' ,
313+ context_id = 'session-xyz' ,
314+ task = None ,
315+ related_tasks = None ,
316+ )
317+
291318 events : list [Any ] = [
292319 Task (** MINIMAL_TASK ),
293320 TaskArtifactUpdateEvent (
@@ -462,8 +489,11 @@ async def test_get_push_notification_success(self) -> None:
462489 )
463490 assert get_response .root .result == task_push_config # type: ignore
464491
492+ @patch (
493+ 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
494+ )
465495 async def test_on_message_stream_new_message_send_push_notification_success (
466- self ,
496+ self , _mock_builder_build : AsyncMock
467497 ) -> None :
468498 mock_agent_executor = AsyncMock (spec = AgentExecutor )
469499 mock_task_store = AsyncMock (spec = TaskStore )
@@ -475,6 +505,13 @@ async def test_on_message_stream_new_message_send_push_notification_success(
475505 self .mock_agent_card .capabilities = AgentCapabilities (
476506 streaming = True , pushNotifications = True
477507 )
508+ _mock_builder_build .return_value = RequestContext (
509+ request = MagicMock (),
510+ task_id = 'task_123' ,
511+ context_id = 'session-xyz' ,
512+ task = None ,
513+ related_tasks = None ,
514+ )
478515
479516 handler = JSONRPCHandler (self .mock_agent_card , request_handler )
480517 events : list [Any ] = [
@@ -642,3 +679,66 @@ async def test_on_resubscribe_no_existing_task_error(self) -> None:
642679 assert len (collected_events ) == 1
643680 self .assertIsInstance (collected_events [0 ].root , JSONRPCErrorResponse )
644681 assert collected_events [0 ].root .error == TaskNotFoundError ()
682+
683+ async def test_on_message_send_task_id_mismatch (self ) -> None :
684+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
685+ mock_task_store = AsyncMock (spec = TaskStore )
686+ request_handler = DefaultRequestHandler (
687+ mock_agent_executor , mock_task_store
688+ )
689+ handler = JSONRPCHandler (self .mock_agent_card , request_handler )
690+ mock_task = Task (** MINIMAL_TASK )
691+ mock_task_store .get .return_value = mock_task
692+ mock_agent_executor .execute .return_value = None
693+
694+ async def streaming_coro ():
695+ yield mock_task
696+
697+ with patch (
698+ 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all' ,
699+ return_value = streaming_coro (),
700+ ):
701+ request = SendMessageRequest (
702+ id = '1' ,
703+ params = MessageSendParams (message = Message (** MESSAGE_PAYLOAD )),
704+ )
705+ response = await handler .on_message_send (request )
706+ assert mock_agent_executor .execute .call_count == 1
707+ self .assertIsInstance (response .root , JSONRPCErrorResponse )
708+ self .assertIsInstance (response .root .error , InternalError ) # type: ignore
709+
710+ async def test_on_message_stream_task_id_mismatch (self ) -> None :
711+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
712+ mock_task_store = AsyncMock (spec = TaskStore )
713+ request_handler = DefaultRequestHandler (
714+ mock_agent_executor , mock_task_store
715+ )
716+
717+ self .mock_agent_card .capabilities = AgentCapabilities (streaming = True )
718+ handler = JSONRPCHandler (self .mock_agent_card , request_handler )
719+ events : list [Any ] = [Task (** MINIMAL_TASK )]
720+
721+ async def streaming_coro ():
722+ for event in events :
723+ yield event
724+
725+ with patch (
726+ 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all' ,
727+ return_value = streaming_coro (),
728+ ):
729+ mock_task_store .get .return_value = None
730+ mock_agent_executor .execute .return_value = None
731+ request = SendStreamingMessageRequest (
732+ id = '1' ,
733+ params = MessageSendParams (message = Message (** MESSAGE_PAYLOAD )),
734+ )
735+ response = handler .on_message_send_stream (request )
736+ assert isinstance (response , AsyncGenerator )
737+ collected_events : list [Any ] = []
738+ async for event in response :
739+ collected_events .append (event )
740+ assert len (collected_events ) == 1
741+ self .assertIsInstance (
742+ collected_events [0 ].root , JSONRPCErrorResponse
743+ )
744+ self .assertIsInstance (collected_events [0 ].root .error , InternalError )
0 commit comments