1- from collections .abc import AsyncGenerator
21from typing import Any
32
43from agent import CurrencyAgent
54from helpers import (
6- create_task_obj ,
75 process_streaming_agent_response ,
86 update_task_with_agent_response ,
97)
8+ from typing_extensions import override
109
11- from a2a .server import AgentExecutor , TaskStore
10+ from a2a .server .agent_execution import BaseAgentExecutor
11+ from a2a .server .events .event_queue import EventQueue
1212from a2a .types import (
13- CancelTaskRequest ,
14- CancelTaskResponse ,
15- JSONRPCErrorResponse ,
1613 MessageSendParams ,
1714 SendMessageRequest ,
18- SendMessageResponse ,
19- SendMessageStreamingRequest ,
20- SendMessageStreamingResponse ,
21- SendMessageStreamingSuccessResponse ,
22- SendMessageSuccessResponse ,
15+ SendStreamingMessageRequest ,
2316 Task ,
24- TaskNotCancelableError ,
25- TaskResubscriptionRequest ,
2617 TextPart ,
27- UnsupportedOperationError ,
2818)
19+ from a2a .utils import create_task_obj
2920
3021
31- class CurrencyAgentExecutor (AgentExecutor ):
22+ class CurrencyAgentExecutor (BaseAgentExecutor ):
3223 """Currency AgentExecutor Example."""
3324
34- def __init__ (self , task_store : TaskStore ):
25+ def __init__ (self ):
3526 self .agent = CurrencyAgent ()
36- self .task_store = task_store
3727
28+ @override
3829 async def on_message_send (
39- self , request : SendMessageRequest , task : Task | None
40- ) -> SendMessageResponse :
30+ self ,
31+ request : SendMessageRequest ,
32+ event_queue : EventQueue ,
33+ task : Task | None ,
34+ ) -> None :
4135 """Handler for 'message/send' requests."""
4236 params : MessageSendParams = request .params
4337 query = self ._get_user_query (params )
4438
4539 if not task :
4640 task = create_task_obj (params )
47- await self .task_store .save (task )
4841
4942 # invoke the underlying agent
5043 agent_response : dict [str , Any ] = self .agent .invoke (
5144 query , task .contextId
5245 )
53-
5446 update_task_with_agent_response (task , agent_response )
55- return SendMessageResponse (
56- root = SendMessageSuccessResponse (id = request .id , result = task )
57- )
58-
59- async def on_message_stream ( # type: ignore
60- self , request : SendMessageStreamingRequest , task : Task | None
61- ) -> AsyncGenerator [SendMessageStreamingResponse , None ]:
47+ event_queue .enqueue_event (task )
48+
49+ @override
50+ async def on_message_stream (
51+ self ,
52+ request : SendStreamingMessageRequest ,
53+ event_queue : EventQueue ,
54+ task : Task | None ,
55+ ) -> None :
6256 """Handler for 'message/sendStream' requests."""
6357 params : MessageSendParams = request .params
6458 query = self ._get_user_query (params )
6559
6660 if not task :
6761 task = create_task_obj (params )
68- await self .task_store .save (task )
62+ # emit the initial task so it is persisted to TaskStore
63+ event_queue .enqueue_event (task )
6964
7065 # kickoff the streaming agent and process responses
7166 async for item in self .agent .stream (query , task .contextId ):
@@ -74,37 +69,9 @@ async def on_message_stream( # type: ignore
7469 )
7570
7671 if task_artifact_update_event :
77- yield SendMessageStreamingResponse (
78- root = SendMessageStreamingSuccessResponse (
79- id = request .id , result = task_artifact_update_event
80- )
81- )
82-
83- yield SendMessageStreamingResponse (
84- root = SendMessageStreamingSuccessResponse (
85- id = request .id , result = task_status_event
86- )
87- )
72+ event_queue .enqueue_event (task_artifact_update_event )
8873
89- async def on_cancel (
90- self , request : CancelTaskRequest , task : Task
91- ) -> CancelTaskResponse :
92- """Handler for 'tasks/cancel' requests."""
93- return CancelTaskResponse (
94- root = JSONRPCErrorResponse (
95- id = request .id , error = TaskNotCancelableError ()
96- )
97- )
98-
99- async def on_resubscribe ( # type: ignore
100- self , request : TaskResubscriptionRequest , task : Task
101- ) -> AsyncGenerator [SendMessageStreamingResponse , None ]:
102- """Handler for 'tasks/resubscribe' requests."""
103- yield SendMessageStreamingResponse (
104- root = JSONRPCErrorResponse (
105- id = request .id , error = UnsupportedOperationError ()
106- )
107- )
74+ event_queue .enqueue_event (task_status_event )
10875
10976 def _get_user_query (self , task_send_params : MessageSendParams ) -> str :
11077 """Helper to get user query from task send params."""
0 commit comments