Skip to content

Commit 73dbc3f

Browse files
committed
Add tests
1 parent bcd51e3 commit 73dbc3f

File tree

1 file changed

+105
-5
lines changed

1 file changed

+105
-5
lines changed

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99
import httpx
1010

11-
from a2a.server.agent_execution import AgentExecutor
11+
from a2a.server.agent_execution import AgentExecutor, RequestContext
1212
from a2a.server.events import (
1313
QueueManager,
1414
)
@@ -55,6 +55,7 @@
5555
TaskStatusUpdateEvent,
5656
TextPart,
5757
UnsupportedOperationError,
58+
InternalError,
5859
)
5960
from a2a.utils.errors import ServerError
6061

@@ -183,7 +184,12 @@ async def test_on_cancel_task_not_found(self) -> None:
183184
mock_task_store.get.assert_called_once_with('nonexistent_id')
184185
mock_agent_executor.cancel.assert_not_called()
185186

186-
async def test_on_message_new_message_success(self) -> None:
187+
@patch(
188+
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
189+
)
190+
async def test_on_message_new_message_success(
191+
self, _mock_builder_build: AsyncMock
192+
) -> None:
187193
mock_agent_executor = AsyncMock(spec=AgentExecutor)
188194
mock_task_store = AsyncMock(spec=TaskStore)
189195
request_handler = DefaultRequestHandler(
@@ -194,6 +200,14 @@ async def test_on_message_new_message_success(self) -> None:
194200
mock_task_store.get.return_value = mock_task
195201
mock_agent_executor.execute.return_value = None
196202

203+
_mock_builder_build.return_value = RequestContext(
204+
request=MagicMock(),
205+
task_id='task_123',
206+
context_id='session-xyz',
207+
task=None,
208+
related_tasks=None,
209+
)
210+
197211
async def streaming_coro():
198212
yield mock_task
199213

@@ -279,15 +293,28 @@ async def streaming_coro():
279293
assert response.root.error == UnsupportedOperationError() # type: ignore
280294
mock_agent_executor.execute.assert_called_once()
281295

282-
async def test_on_message_stream_new_message_success(self) -> None:
296+
@patch(
297+
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
298+
)
299+
async def test_on_message_stream_new_message_success(
300+
self, _mock_builder_build: AsyncMock
301+
) -> None:
283302
mock_agent_executor = AsyncMock(spec=AgentExecutor)
284303
mock_task_store = AsyncMock(spec=TaskStore)
285304
request_handler = DefaultRequestHandler(
286305
mock_agent_executor, mock_task_store
287306
)
288-
self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
289307

308+
self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
290309
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
310+
_mock_builder_build.return_value = RequestContext(
311+
request=MagicMock(),
312+
task_id='task_123',
313+
context_id='session-xyz',
314+
task=None,
315+
related_tasks=None,
316+
)
317+
291318
events: list[Any] = [
292319
Task(**MINIMAL_TASK),
293320
TaskArtifactUpdateEvent(
@@ -462,8 +489,11 @@ async def test_get_push_notification_success(self) -> None:
462489
)
463490
assert get_response.root.result == task_push_config # type: ignore
464491

492+
@patch(
493+
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
494+
)
465495
async def test_on_message_stream_new_message_send_push_notification_success(
466-
self,
496+
self, _mock_builder_build: AsyncMock
467497
) -> None:
468498
mock_agent_executor = AsyncMock(spec=AgentExecutor)
469499
mock_task_store = AsyncMock(spec=TaskStore)
@@ -475,6 +505,13 @@ async def test_on_message_stream_new_message_send_push_notification_success(
475505
self.mock_agent_card.capabilities = AgentCapabilities(
476506
streaming=True, pushNotifications=True
477507
)
508+
_mock_builder_build.return_value = RequestContext(
509+
request=MagicMock(),
510+
task_id='task_123',
511+
context_id='session-xyz',
512+
task=None,
513+
related_tasks=None,
514+
)
478515

479516
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
480517
events: list[Any] = [
@@ -642,3 +679,66 @@ async def test_on_resubscribe_no_existing_task_error(self) -> None:
642679
assert len(collected_events) == 1
643680
self.assertIsInstance(collected_events[0].root, JSONRPCErrorResponse)
644681
assert collected_events[0].root.error == TaskNotFoundError()
682+
683+
async def test_on_message_send_task_id_mismatch(self) -> None:
684+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
685+
mock_task_store = AsyncMock(spec=TaskStore)
686+
request_handler = DefaultRequestHandler(
687+
mock_agent_executor, mock_task_store
688+
)
689+
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
690+
mock_task = Task(**MINIMAL_TASK)
691+
mock_task_store.get.return_value = mock_task
692+
mock_agent_executor.execute.return_value = None
693+
694+
async def streaming_coro():
695+
yield mock_task
696+
697+
with patch(
698+
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
699+
return_value=streaming_coro(),
700+
):
701+
request = SendMessageRequest(
702+
id='1',
703+
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
704+
)
705+
response = await handler.on_message_send(request)
706+
assert mock_agent_executor.execute.call_count == 1
707+
self.assertIsInstance(response.root, JSONRPCErrorResponse)
708+
self.assertIsInstance(response.root.error, InternalError) # type: ignore
709+
710+
async def test_on_message_stream_task_id_mismatch(self) -> None:
711+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
712+
mock_task_store = AsyncMock(spec=TaskStore)
713+
request_handler = DefaultRequestHandler(
714+
mock_agent_executor, mock_task_store
715+
)
716+
717+
self.mock_agent_card.capabilities = AgentCapabilities(streaming=True)
718+
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
719+
events: list[Any] = [Task(**MINIMAL_TASK)]
720+
721+
async def streaming_coro():
722+
for event in events:
723+
yield event
724+
725+
with patch(
726+
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
727+
return_value=streaming_coro(),
728+
):
729+
mock_task_store.get.return_value = None
730+
mock_agent_executor.execute.return_value = None
731+
request = SendStreamingMessageRequest(
732+
id='1',
733+
params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)),
734+
)
735+
response = handler.on_message_send_stream(request)
736+
assert isinstance(response, AsyncGenerator)
737+
collected_events: list[Any] = []
738+
async for event in response:
739+
collected_events.append(event)
740+
assert len(collected_events) == 1
741+
self.assertIsInstance(
742+
collected_events[0].root, JSONRPCErrorResponse
743+
)
744+
self.assertIsInstance(collected_events[0].root.error, InternalError)

0 commit comments

Comments
 (0)