@@ -169,23 +169,25 @@ async def _run_event_stream(
169169 await self .agent_executor .execute (request , queue )
170170 await queue .close ()
171171
172- async def on_message_send (
172+ async def _setup_message_execution (
173173 self ,
174174 params : MessageSendParams ,
175175 context : ServerCallContext | None = None ,
176- ) -> Message | Task :
177- """Default handler for 'message/send' interface ( non-streaming) .
176+ ) -> tuple [ TaskManager , str , EventQueue , ResultAggregator , asyncio . Task ] :
177+ """Common setup logic for both streaming and non-streaming message handling .
178178
179- Starts the agent execution for the message and waits for the final
180- result (Task or Message).
179+ Returns:
180+ A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
181181 """
182+ # Create task manager and validate existing task
182183 task_manager = TaskManager (
183184 task_id = params .message .taskId ,
184185 context_id = params .message .contextId ,
185186 task_store = self .task_store ,
186187 initial_message = params .message ,
187188 )
188189 task : Task | None = await task_manager .get_task ()
190+
189191 if task :
190192 if task .status .state in TERMINAL_TASK_STATES :
191193 raise ServerError (
@@ -207,6 +209,8 @@ async def on_message_send(
207209 await self ._push_notifier .set_info (
208210 task .id , params .configuration .pushNotificationConfig
209211 )
212+
213+ # Build request context
210214 request_context = await self ._request_context_builder .build (
211215 params = params ,
212216 task_id = task .id if task else None ,
@@ -223,13 +227,49 @@ async def on_message_send(
223227 result_aggregator = ResultAggregator (task_manager )
224228 # TODO: to manage the non-blocking flows.
225229 producer_task = asyncio .create_task (
226- self ._run_event_stream (
227- request_context ,
228- queue ,
229- )
230+ self ._run_event_stream (request_context , queue )
230231 )
231232 await self ._register_producer (task_id , producer_task )
232233
234+ return task_manager , task_id , queue , result_aggregator , producer_task
235+
236+ def _validate_task_id_match (self , task_id : str , event_task_id : str ) -> None :
237+ """Validates that agent-generated task ID matches the expected task ID."""
238+ if task_id != event_task_id :
239+ logger .error (
240+ f'Agent generated task_id={ event_task_id } does not match the RequestContext task_id={ task_id } .'
241+ )
242+ raise ServerError (
243+ InternalError (message = 'Task ID mismatch in agent response' )
244+ )
245+
246+ async def _send_push_notification_if_needed (
247+ self , task_id : str , result_aggregator : ResultAggregator
248+ ) -> None :
249+ """Sends push notification if configured and task is available."""
250+ if self ._push_notifier and task_id :
251+ latest_task = await result_aggregator .current_result
252+ if isinstance (latest_task , Task ):
253+ await self ._push_notifier .send_notification (latest_task )
254+
255+ async def on_message_send (
256+ self ,
257+ params : MessageSendParams ,
258+ context : ServerCallContext | None = None ,
259+ ) -> Message | Task :
260+ """Default handler for 'message/send' interface (non-streaming).
261+
262+ Starts the agent execution for the message and waits for the final
263+ result (Task or Message).
264+ """
265+ (
266+ task_manager ,
267+ task_id ,
268+ queue ,
269+ result_aggregator ,
270+ producer_task ,
271+ ) = await self ._setup_message_execution (params , context )
272+
233273 consumer = EventConsumer (queue )
234274 producer_task .add_done_callback (consumer .agent_task_callback )
235275
@@ -242,18 +282,12 @@ async def on_message_send(
242282 if not result :
243283 raise ServerError (error = InternalError ())
244284
245- if isinstance (result , Task ) and task_id != result .id :
246- logger .error (
247- f'Agent generated task_id={ result .id } does not match the RequestContext task_id={ task_id } .'
248- )
249- raise ServerError (
250- InternalError (message = 'Task ID mismatch in agent response' )
251- )
285+ if isinstance (result , Task ):
286+ self ._validate_task_id_match (task_id , result .id )
252287
253- if self ._push_notifier and task_id :
254- latest_task = await result_aggregator .current_result
255- if isinstance (latest_task , Task ):
256- await self ._push_notifier .send_notification (latest_task )
288+ await self ._send_push_notification_if_needed (
289+ task_id , result_aggregator
290+ )
257291
258292 finally :
259293 if interrupted :
@@ -276,85 +310,34 @@ async def on_message_send_stream(
276310 Starts the agent execution and yields events as they are produced
277311 by the agent.
278312 """
279- task_manager = TaskManager (
280- task_id = params .message .taskId ,
281- context_id = params .message .contextId ,
282- task_store = self .task_store ,
283- initial_message = params .message ,
284- )
285- task : Task | None = await task_manager .get_task ()
286-
287- if task :
288- if task .status .state in TERMINAL_TASK_STATES :
289- raise ServerError (
290- error = InvalidParamsError (
291- message = f'Task { task .id } is in terminal state: { task .status .state } '
292- )
293- )
294-
295- task = task_manager .update_with_message (params .message , task )
296- if self .should_add_push_info (params ):
297- assert isinstance (self ._push_notifier , PushNotifier )
298- assert isinstance (
299- params .configuration , MessageSendConfiguration
300- )
301- assert isinstance (
302- params .configuration .pushNotificationConfig ,
303- PushNotificationConfig ,
304- )
305- await self ._push_notifier .set_info (
306- task .id , params .configuration .pushNotificationConfig
307- )
308- else :
309- queue = EventQueue ()
310- result_aggregator = ResultAggregator (task_manager )
311- request_context = await self ._request_context_builder .build (
312- params = params ,
313- task_id = task .id if task else None ,
314- context_id = params .message .contextId ,
315- task = task ,
316- context = context ,
317- )
318-
319- task_id = cast ('str' , request_context .task_id )
320- queue = await self ._queue_manager .create_or_tap (task_id )
321- producer_task = asyncio .create_task (
322- self ._run_event_stream (
323- request_context ,
324- queue ,
325- )
326- )
327- await self ._register_producer (task_id , producer_task )
313+ (
314+ task_manager ,
315+ task_id ,
316+ queue ,
317+ result_aggregator ,
318+ producer_task ,
319+ ) = await self ._setup_message_execution (params , context )
328320
329321 try :
330322 consumer = EventConsumer (queue )
331323 producer_task .add_done_callback (consumer .agent_task_callback )
332324 async for event in result_aggregator .consume_and_emit (consumer ):
333325 if isinstance (event , Task ):
334- if task_id != event .id :
335- logger .error (
336- f'Agent generated task_id={ event .id } does not match the RequestContext task_id={ task_id } .'
337- )
338- raise ServerError (
339- InternalError (
340- message = 'Task ID mismatch in agent response'
341- )
342- )
343-
344- if (
345- self ._push_notifier
346- and params .configuration
347- and params .configuration .pushNotificationConfig
348- ):
349- await self ._push_notifier .set_info (
350- task_id ,
351- params .configuration .pushNotificationConfig ,
352- )
353-
354- if self ._push_notifier and task_id :
355- latest_task = await result_aggregator .current_result
356- if isinstance (latest_task , Task ):
357- await self ._push_notifier .send_notification (latest_task )
326+ self ._validate_task_id_match (task_id , event .id )
327+
328+ if (
329+ self ._push_notifier
330+ and params .configuration
331+ and params .configuration .pushNotificationConfig
332+ ):
333+ await self ._push_notifier .set_info (
334+ task_id ,
335+ params .configuration .pushNotificationConfig ,
336+ )
337+
338+ await self ._send_push_notification_if_needed (
339+ task_id , result_aggregator
340+ )
358341 yield event
359342 finally :
360343 await self ._cleanup_producer (producer_task , task_id )
0 commit comments