|
1 | 1 | import asyncio |
| 2 | +import logging |
2 | 3 | import time |
3 | 4 |
|
4 | 5 | from unittest.mock import ( |
|
50 | 51 | TextPart, |
51 | 52 | UnsupportedOperationError, |
52 | 53 | ) |
| 54 | +from a2a.utils import ( |
| 55 | + new_task, |
| 56 | +) |
53 | 57 |
|
54 | 58 |
|
55 | 59 | class DummyAgentExecutor(AgentExecutor): |
@@ -579,6 +583,79 @@ async def test_on_message_send_task_id_mismatch(): |
579 | 583 | assert 'Task ID mismatch' in exc_info.value.error.message # type: ignore |
580 | 584 |
|
581 | 585 |
|
| 586 | +class HelloAgentExecutor(AgentExecutor): |
| 587 | + async def execute(self, context: RequestContext, event_queue: EventQueue): |
| 588 | + task = context.current_task |
| 589 | + if not task: |
| 590 | + assert context.message is not None, ( |
| 591 | + 'A message is required to create a new task' |
| 592 | + ) |
| 593 | + task = new_task(context.message) # type: ignore |
| 594 | + await event_queue.enqueue_event(task) |
| 595 | + updater = TaskUpdater(event_queue, task.id, task.context_id) |
| 596 | + |
| 597 | + try: |
| 598 | + parts = [Part(root=TextPart(text='I am working'))] |
| 599 | + await updater.update_status( |
| 600 | + TaskState.working, |
| 601 | + message=updater.new_agent_message(parts), |
| 602 | + ) |
| 603 | + except Exception as e: |
| 604 | + # Stop processing when the event loop is closed |
| 605 | + logging.warning('Error: %s', e) |
| 606 | + return |
| 607 | + await updater.add_artifact( |
| 608 | + [Part(root=TextPart(text='Hello world!'))], |
| 609 | + name='conversion_result', |
| 610 | + ) |
| 611 | + await updater.complete() |
| 612 | + |
| 613 | + async def cancel(self, context: RequestContext, event_queue: EventQueue): |
| 614 | + pass |
| 615 | + |
| 616 | + |
| 617 | +@pytest.mark.asyncio |
| 618 | +async def test_on_message_send_non_blocking(): |
| 619 | + task_store = InMemoryTaskStore() |
| 620 | + push_store = InMemoryPushNotificationConfigStore() |
| 621 | + |
| 622 | + request_handler = DefaultRequestHandler( |
| 623 | + agent_executor=HelloAgentExecutor(), |
| 624 | + task_store=task_store, |
| 625 | + push_config_store=push_store, |
| 626 | + ) |
| 627 | + params = MessageSendParams( |
| 628 | + message=Message( |
| 629 | + role=Role.user, |
| 630 | + message_id='msg_push', |
| 631 | + parts=[Part(root=TextPart(text='Hi'))], |
| 632 | + ), |
| 633 | + configuration=MessageSendConfiguration( |
| 634 | + blocking=False, accepted_output_modes=['text/plain'] |
| 635 | + ), |
| 636 | + ) |
| 637 | + |
| 638 | + result = await request_handler.on_message_send( |
| 639 | + params, create_server_call_context() |
| 640 | + ) |
| 641 | + |
| 642 | + assert result is not None |
| 643 | + assert isinstance(result, Task) |
| 644 | + assert result.status.state == TaskState.submitted |
| 645 | + |
| 646 | + # Polling for 500ms until task is completed. |
| 647 | + task: Task | None = None |
| 648 | + for _ in range(5): |
| 649 | + await asyncio.sleep(0.1) |
| 650 | + task = await task_store.get(result.id) |
| 651 | + assert task is not None |
| 652 | + if task.status.state == TaskState.completed: |
| 653 | + break |
| 654 | + |
| 655 | + assert task is not None |
| 656 | + assert task.status.state == TaskState.completed |
| 657 | + |
| 658 | + |
582 | 659 | @pytest.mark.asyncio |
583 | 660 | async def test_on_message_send_interrupted_flow(): |
584 | 661 | """Test on_message_send when flow is interrupted (e.g., auth_required).""" |
|
0 commit comments