88import httpx
99import pytest
1010
11- from a2a .server .agent_execution import AgentExecutor
11+
12+ from a2a .server .agent_execution import AgentExecutor , RequestContext
1213from a2a .server .agent_execution .request_context_builder import (
1314 RequestContextBuilder ,
1415)
5960 TaskStatusUpdateEvent ,
6061 TextPart ,
6162 UnsupportedOperationError ,
63+ InternalError ,
6264)
6365from a2a .utils .errors import ServerError
6466
@@ -188,7 +190,12 @@ async def test_on_cancel_task_not_found(self) -> None:
188190 mock_task_store .get .assert_called_once_with ('nonexistent_id' )
189191 mock_agent_executor .cancel .assert_not_called ()
190192
191- async def test_on_message_new_message_success (self ) -> None :
193+ @patch (
194+ 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
195+ )
196+ async def test_on_message_new_message_success (
197+ self , _mock_builder_build : AsyncMock
198+ ) -> None :
192199 mock_agent_executor = AsyncMock (spec = AgentExecutor )
193200 mock_task_store = AsyncMock (spec = TaskStore )
194201 request_handler = DefaultRequestHandler (
@@ -199,6 +206,14 @@ async def test_on_message_new_message_success(self) -> None:
199206 mock_task_store .get .return_value = mock_task
200207 mock_agent_executor .execute .return_value = None
201208
209+ _mock_builder_build .return_value = RequestContext (
210+ request = MagicMock (),
211+ task_id = 'task_123' ,
212+ context_id = 'session-xyz' ,
213+ task = None ,
214+ related_tasks = None ,
215+ )
216+
202217 async def streaming_coro ():
203218 yield mock_task
204219
@@ -284,15 +299,28 @@ async def streaming_coro():
284299 assert response .root .error == UnsupportedOperationError () # type: ignore
285300 mock_agent_executor .execute .assert_called_once ()
286301
287- async def test_on_message_stream_new_message_success (self ) -> None :
302+ @patch (
303+ 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
304+ )
305+ async def test_on_message_stream_new_message_success (
306+ self , _mock_builder_build : AsyncMock
307+ ) -> None :
288308 mock_agent_executor = AsyncMock (spec = AgentExecutor )
289309 mock_task_store = AsyncMock (spec = TaskStore )
290310 request_handler = DefaultRequestHandler (
291311 mock_agent_executor , mock_task_store
292312 )
293- self .mock_agent_card .capabilities = AgentCapabilities (streaming = True )
294313
314+ self .mock_agent_card .capabilities = AgentCapabilities (streaming = True )
295315 handler = JSONRPCHandler (self .mock_agent_card , request_handler )
316+ _mock_builder_build .return_value = RequestContext (
317+ request = MagicMock (),
318+ task_id = 'task_123' ,
319+ context_id = 'session-xyz' ,
320+ task = None ,
321+ related_tasks = None ,
322+ )
323+
296324 events : list [Any ] = [
297325 Task (** MINIMAL_TASK ),
298326 TaskArtifactUpdateEvent (
@@ -467,8 +495,11 @@ async def test_get_push_notification_success(self) -> None:
467495 )
468496 assert get_response .root .result == task_push_config # type: ignore
469497
498+ @patch (
499+ 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
500+ )
470501 async def test_on_message_stream_new_message_send_push_notification_success (
471- self ,
502+ self , _mock_builder_build : AsyncMock
472503 ) -> None :
473504 mock_agent_executor = AsyncMock (spec = AgentExecutor )
474505 mock_task_store = AsyncMock (spec = TaskStore )
@@ -480,6 +511,13 @@ async def test_on_message_stream_new_message_send_push_notification_success(
480511 self .mock_agent_card .capabilities = AgentCapabilities (
481512 streaming = True , pushNotifications = True
482513 )
514+ _mock_builder_build .return_value = RequestContext (
515+ request = MagicMock (),
516+ task_id = 'task_123' ,
517+ context_id = 'session-xyz' ,
518+ task = None ,
519+ related_tasks = None ,
520+ )
483521
484522 handler = JSONRPCHandler (self .mock_agent_card , request_handler )
485523 events : list [Any ] = [
@@ -738,7 +776,8 @@ async def test_on_get_push_notification_no_push_notifier(self) -> None:
738776
739777 # Assert
740778 self .assertIsInstance (response .root , JSONRPCErrorResponse )
741- self .assertEqual (response .root .error , UnsupportedOperationError ())
779+ self .assertEqual (response .root .error , UnsupportedOperationError ()) # type: ignore
780+
742781
743782 async def test_on_set_push_notification_no_push_notifier (self ) -> None :
744783 """Test set_push_notification with no push notifier configured."""
@@ -771,7 +810,8 @@ async def test_on_set_push_notification_no_push_notifier(self) -> None:
771810
772811 # Assert
773812 self .assertIsInstance (response .root , JSONRPCErrorResponse )
774- self .assertEqual (response .root .error , UnsupportedOperationError ())
813+ self .assertEqual (response .root .error , UnsupportedOperationError ()) # type: ignore
814+
775815
776816 async def test_on_message_send_internal_error (self ) -> None :
777817 """Test on_message_send with an internal error."""
@@ -800,7 +840,8 @@ async def raise_server_error(*args, **kwargs):
800840
801841 # Assert
802842 self .assertIsInstance (response .root , JSONRPCErrorResponse )
803- self .assertIsInstance (response .root .error , InternalError )
843+ self .assertIsInstance (response .root .error , InternalError ) # type: ignore
844+
804845
805846 async def test_on_message_stream_internal_error (self ) -> None :
806847 """Test on_message_send_stream with an internal error."""
@@ -906,3 +947,66 @@ async def consume_raises_error(*args, **kwargs):
906947 # Assert
907948 self .assertIsInstance (response .root , JSONRPCErrorResponse )
908949 self .assertEqual (response .root .error , UnsupportedOperationError ())
950+
951+ async def test_on_message_send_task_id_mismatch (self ) -> None :
952+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
953+ mock_task_store = AsyncMock (spec = TaskStore )
954+ request_handler = DefaultRequestHandler (
955+ mock_agent_executor , mock_task_store
956+ )
957+ handler = JSONRPCHandler (self .mock_agent_card , request_handler )
958+ mock_task = Task (** MINIMAL_TASK )
959+ mock_task_store .get .return_value = mock_task
960+ mock_agent_executor .execute .return_value = None
961+
962+ async def streaming_coro ():
963+ yield mock_task
964+
965+ with patch (
966+ 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all' ,
967+ return_value = streaming_coro (),
968+ ):
969+ request = SendMessageRequest (
970+ id = '1' ,
971+ params = MessageSendParams (message = Message (** MESSAGE_PAYLOAD )),
972+ )
973+ response = await handler .on_message_send (request )
974+ assert mock_agent_executor .execute .call_count == 1
975+ self .assertIsInstance (response .root , JSONRPCErrorResponse )
976+ self .assertIsInstance (response .root .error , InternalError ) # type: ignore
977+
978+ async def test_on_message_stream_task_id_mismatch (self ) -> None :
979+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
980+ mock_task_store = AsyncMock (spec = TaskStore )
981+ request_handler = DefaultRequestHandler (
982+ mock_agent_executor , mock_task_store
983+ )
984+
985+ self .mock_agent_card .capabilities = AgentCapabilities (streaming = True )
986+ handler = JSONRPCHandler (self .mock_agent_card , request_handler )
987+ events : list [Any ] = [Task (** MINIMAL_TASK )]
988+
989+ async def streaming_coro ():
990+ for event in events :
991+ yield event
992+
993+ with patch (
994+ 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all' ,
995+ return_value = streaming_coro (),
996+ ):
997+ mock_task_store .get .return_value = None
998+ mock_agent_executor .execute .return_value = None
999+ request = SendStreamingMessageRequest (
1000+ id = '1' ,
1001+ params = MessageSendParams (message = Message (** MESSAGE_PAYLOAD )),
1002+ )
1003+ response = handler .on_message_send_stream (request )
1004+ assert isinstance (response , AsyncGenerator )
1005+ collected_events : list [Any ] = []
1006+ async for event in response :
1007+ collected_events .append (event )
1008+ assert len (collected_events ) == 1
1009+ self .assertIsInstance (
1010+ collected_events [0 ].root , JSONRPCErrorResponse
1011+ )
1012+ self .assertIsInstance (collected_events [0 ].root .error , InternalError )
0 commit comments