Skip to content

Commit f22e479

Browse files
test: add test for non-blocking sendMessage (#355)
Co-authored-by: Holt Skinner <[email protected]> Co-authored-by: Holt Skinner <[email protected]>
1 parent b469036 commit f22e479

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import logging
23
import time
34

45
from unittest.mock import (
@@ -50,6 +51,9 @@
5051
TextPart,
5152
UnsupportedOperationError,
5253
)
54+
from a2a.utils import (
55+
new_task,
56+
)
5357

5458

5559
class DummyAgentExecutor(AgentExecutor):
@@ -579,6 +583,79 @@ async def test_on_message_send_task_id_mismatch():
579583
assert 'Task ID mismatch' in exc_info.value.error.message # type: ignore
580584

581585

586+
class HelloAgentExecutor(AgentExecutor):
587+
async def execute(self, context: RequestContext, event_queue: EventQueue):
588+
task = context.current_task
589+
if not task:
590+
assert context.message is not None, (
591+
'A message is required to create a new task'
592+
)
593+
task = new_task(context.message) # type: ignore
594+
await event_queue.enqueue_event(task)
595+
updater = TaskUpdater(event_queue, task.id, task.context_id)
596+
597+
try:
598+
parts = [Part(root=TextPart(text='I am working'))]
599+
await updater.update_status(
600+
TaskState.working,
601+
message=updater.new_agent_message(parts),
602+
)
603+
except Exception as e:
604+
# Stop processing when the event loop is closed
605+
logging.warning('Error: %s', e)
606+
return
607+
await updater.add_artifact(
608+
[Part(root=TextPart(text='Hello world!'))],
609+
name='conversion_result',
610+
)
611+
await updater.complete()
612+
613+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
614+
pass
615+
616+
617+
@pytest.mark.asyncio
618+
async def test_on_message_send_non_blocking():
619+
task_store = InMemoryTaskStore()
620+
push_store = InMemoryPushNotificationConfigStore()
621+
622+
request_handler = DefaultRequestHandler(
623+
agent_executor=HelloAgentExecutor(),
624+
task_store=task_store,
625+
push_config_store=push_store,
626+
)
627+
params = MessageSendParams(
628+
message=Message(
629+
role=Role.user,
630+
message_id='msg_push',
631+
parts=[Part(root=TextPart(text='Hi'))],
632+
),
633+
configuration=MessageSendConfiguration(
634+
blocking=False, accepted_output_modes=['text/plain']
635+
),
636+
)
637+
638+
result = await request_handler.on_message_send(
639+
params, create_server_call_context()
640+
)
641+
642+
assert result is not None
643+
assert isinstance(result, Task)
644+
assert result.status.state == TaskState.submitted
645+
646+
# Polling for 500ms until task is completed.
647+
task: Task | None = None
648+
for _ in range(5):
649+
await asyncio.sleep(0.1)
650+
task = await task_store.get(result.id)
651+
assert task is not None
652+
if task.status.state == TaskState.completed:
653+
break
654+
655+
assert task is not None
656+
assert task.status.state == TaskState.completed
657+
658+
582659
@pytest.mark.asyncio
583660
async def test_on_message_send_interrupted_flow():
584661
"""Test on_message_send when flow is interrupted (e.g., auth_required)."""

0 commit comments

Comments
 (0)