Skip to content

Commit 2be38bd

Browse files
committed
chore: add test for non-blocking sendMessage
1 parent d37ce7b commit 2be38bd

File tree

1 file changed

+72
-2
lines changed

1 file changed

+72
-2
lines changed

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@
5050
TextPart,
5151
UnsupportedOperationError,
5252
)
53-
53+
from a2a.utils import (
54+
new_task,
55+
)
5456

5557
class DummyAgentExecutor(AgentExecutor):
5658
async def execute(self, context: RequestContext, event_queue: EventQueue):
@@ -75,7 +77,6 @@ async def _run(self):
7577
async def cancel(self, context: RequestContext, event_queue: EventQueue):
7678
pass
7779

78-
7980
# Helper to create a simple task for tests
8081
def create_sample_task(
8182
task_id='task1', status_state=TaskState.submitted, context_id='ctx1'
@@ -579,6 +580,75 @@ async def test_on_message_send_task_id_mismatch():
579580
assert 'Task ID mismatch' in exc_info.value.error.message # type: ignore
580581

581582

583+
class HelloAgentExecutor(AgentExecutor):
584+
async def execute(self, context: RequestContext, event_queue: EventQueue):
585+
task = context.current_task
586+
if not task:
587+
task = new_task(context.message) # type: ignore
588+
await event_queue.enqueue_event(task)
589+
updater = TaskUpdater(event_queue, task.id, task.context_id)
590+
591+
try:
592+
parts = [Part(root=TextPart(text=f'I am working'))]
593+
await updater.update_status(
594+
TaskState.working,
595+
message=updater.new_agent_message(parts),
596+
)
597+
except RuntimeError as e:
598+
# Stop processing when the event loop is closed
599+
print("Runtim error", e)
600+
return
601+
await updater.add_artifact(
602+
[Part(root=TextPart(text="Hello world!"))],
603+
name='conversion_result',
604+
)
605+
await updater.complete()
606+
607+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
608+
pass
609+
610+
@pytest.mark.asyncio
611+
async def test_on_message_send_non_blocking():
612+
task_store = InMemoryTaskStore()
613+
push_store = InMemoryPushNotificationConfigStore()
614+
615+
request_handler = DefaultRequestHandler(
616+
agent_executor=HelloAgentExecutor(),
617+
task_store=task_store,
618+
push_config_store=push_store,
619+
)
620+
params = MessageSendParams(
621+
message=Message(
622+
role=Role.user,
623+
message_id='msg_push',
624+
parts=[Part(root=TextPart(text=f'Hi'))]
625+
),
626+
configuration=MessageSendConfiguration(
627+
blocking=False,
628+
accepted_output_modes=['text/plain']
629+
)
630+
)
631+
632+
result = await request_handler.on_message_send(
633+
params, create_server_call_context()
634+
)
635+
636+
assert result is not None
637+
assert type(result) == Task
638+
result.status.state = TaskState.submitted
639+
640+
# Polling for 500ms until task is completed.
641+
task: Task | None = None
642+
for _ in range(5):
643+
await asyncio.sleep(0.1)
644+
task = await task_store.get(result.id)
645+
assert task is not None
646+
if task.status.state == TaskState.completed:
647+
break
648+
649+
assert task is not None
650+
assert task.status.state == TaskState.completed
651+
582652
@pytest.mark.asyncio
583653
async def test_on_message_send_interrupted_flow():
584654
"""Test on_message_send when flow is interrupted (e.g., auth_required)."""

0 commit comments

Comments
 (0)