5050 TextPart ,
5151 UnsupportedOperationError ,
5252)
53-
53+ from a2a .utils import (
54+ new_task ,
55+ )
5456
5557class DummyAgentExecutor (AgentExecutor ):
5658 async def execute (self , context : RequestContext , event_queue : EventQueue ):
@@ -75,7 +77,6 @@ async def _run(self):
7577 async def cancel (self , context : RequestContext , event_queue : EventQueue ):
7678 pass
7779
78-
7980# Helper to create a simple task for tests
8081def create_sample_task (
8182 task_id = 'task1' , status_state = TaskState .submitted , context_id = 'ctx1'
@@ -579,6 +580,75 @@ async def test_on_message_send_task_id_mismatch():
579580 assert 'Task ID mismatch' in exc_info .value .error .message # type: ignore
580581
581582
583+ class HelloAgentExecutor (AgentExecutor ):
584+ async def execute (self , context : RequestContext , event_queue : EventQueue ):
585+ task = context .current_task
586+ if not task :
587+ task = new_task (context .message ) # type: ignore
588+ await event_queue .enqueue_event (task )
589+ updater = TaskUpdater (event_queue , task .id , task .context_id )
590+
591+ try :
592+ parts = [Part (root = TextPart (text = f'I am working' ))]
593+ await updater .update_status (
594+ TaskState .working ,
595+ message = updater .new_agent_message (parts ),
596+ )
597+ except RuntimeError as e :
598+ # Stop processing when the event loop is closed
599+ print ("Runtim error" , e )
600+ return
601+ await updater .add_artifact (
602+ [Part (root = TextPart (text = "Hello world!" ))],
603+ name = 'conversion_result' ,
604+ )
605+ await updater .complete ()
606+
607+ async def cancel (self , context : RequestContext , event_queue : EventQueue ):
608+ pass
609+
610+ @pytest .mark .asyncio
611+ async def test_on_message_send_non_blocking ():
612+ task_store = InMemoryTaskStore ()
613+ push_store = InMemoryPushNotificationConfigStore ()
614+
615+ request_handler = DefaultRequestHandler (
616+ agent_executor = HelloAgentExecutor (),
617+ task_store = task_store ,
618+ push_config_store = push_store ,
619+ )
620+ params = MessageSendParams (
621+ message = Message (
622+ role = Role .user ,
623+ message_id = 'msg_push' ,
624+ parts = [Part (root = TextPart (text = f'Hi' ))]
625+ ),
626+ configuration = MessageSendConfiguration (
627+ blocking = False ,
628+ accepted_output_modes = ['text/plain' ]
629+ )
630+ )
631+
632+ result = await request_handler .on_message_send (
633+ params , create_server_call_context ()
634+ )
635+
636+ assert result is not None
637+ assert type (result ) == Task
638+ result .status .state = TaskState .submitted
639+
640+ # Polling for 500ms until task is completed.
641+ task : Task | None = None
642+ for _ in range (5 ):
643+ await asyncio .sleep (0.1 )
644+ task = await task_store .get (result .id )
645+ assert task is not None
646+ if task .status .state == TaskState .completed :
647+ break
648+
649+ assert task is not None
650+ assert task .status .state == TaskState .completed
651+
582652@pytest .mark .asyncio
583653async def test_on_message_send_interrupted_flow ():
584654 """Test on_message_send when flow is interrupted (e.g., auth_required)."""
0 commit comments