Skip to content

Commit 9251b76

Browse files
committed
chore(tests): Add new tests to increase coverage
1 parent b952a14 commit 9251b76

File tree

7 files changed

+546
-9
lines changed

7 files changed

+546
-9
lines changed

tests/client/test_grpc_client.py

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
from unittest.mock import AsyncMock, MagicMock
2+
3+
import grpc
4+
import pytest
5+
6+
from a2a.client import A2AGrpcClient
7+
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
8+
from a2a.types import (
9+
AgentCapabilities,
10+
AgentCard,
11+
Message,
12+
MessageSendParams,
13+
Part,
14+
PushNotificationConfig,
15+
Role,
16+
Task,
17+
TaskArtifactUpdateEvent,
18+
TaskIdParams,
19+
TaskPushNotificationConfig,
20+
TaskQueryParams,
21+
TaskState,
22+
TaskStatus,
23+
TaskStatusUpdateEvent,
24+
TextPart,
25+
)
26+
from a2a.utils import proto_utils
27+
28+
29+
# Fixtures
30+
@pytest.fixture
31+
def mock_grpc_stub() -> AsyncMock:
32+
"""Provides a mock gRPC stub."""
33+
return AsyncMock(spec=a2a_pb2_grpc.A2AServiceStub)
34+
35+
36+
@pytest.fixture
37+
def sample_agent_card() -> AgentCard:
38+
"""Provides a minimal agent card for initialization."""
39+
return AgentCard(
40+
name='gRPC Test Agent',
41+
description='Agent for testing gRPC client',
42+
url='grpc://localhost:50051',
43+
version='1.0',
44+
capabilities=AgentCapabilities(streaming=True, pushNotifications=True),
45+
defaultInputModes=['text/plain'],
46+
defaultOutputModes=['text/plain'],
47+
skills=[],
48+
)
49+
50+
51+
@pytest.fixture
52+
def grpc_client(
53+
mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard
54+
) -> A2AGrpcClient:
55+
"""Provides an A2AGrpcClient instance."""
56+
return A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=sample_agent_card)
57+
58+
59+
@pytest.fixture
60+
def sample_message_send_params() -> MessageSendParams:
61+
"""Provides a sample MessageSendParams object."""
62+
return MessageSendParams(
63+
message=Message(
64+
role=Role.user,
65+
messageId='msg-1',
66+
parts=[Part(root=TextPart(text='Hello'))],
67+
)
68+
)
69+
70+
71+
@pytest.fixture
72+
def sample_task() -> Task:
73+
"""Provides a sample Task object."""
74+
return Task(
75+
id='task-1',
76+
contextId='ctx-1',
77+
status=TaskStatus(state=TaskState.completed),
78+
)
79+
80+
81+
@pytest.fixture
82+
def sample_message() -> Message:
83+
"""Provides a sample Message object."""
84+
return Message(
85+
role=Role.agent,
86+
messageId='msg-response',
87+
parts=[Part(root=TextPart(text='Hi there'))],
88+
)
89+
90+
91+
@pytest.mark.asyncio
92+
async def test_send_message_task_response(
93+
grpc_client: A2AGrpcClient,
94+
mock_grpc_stub: AsyncMock,
95+
sample_message_send_params: MessageSendParams,
96+
sample_task: Task,
97+
):
98+
"""Test send_message that returns a Task."""
99+
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
100+
task=proto_utils.ToProto.task(sample_task)
101+
)
102+
103+
response = await grpc_client.send_message(sample_message_send_params)
104+
105+
mock_grpc_stub.SendMessage.assert_awaited_once()
106+
assert isinstance(response, Task)
107+
assert response.id == sample_task.id
108+
109+
110+
@pytest.mark.asyncio
111+
async def test_send_message_message_response(
112+
grpc_client: A2AGrpcClient,
113+
mock_grpc_stub: AsyncMock,
114+
sample_message_send_params: MessageSendParams,
115+
sample_message: Message,
116+
):
117+
"""Test send_message that returns a Message."""
118+
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
119+
msg=proto_utils.ToProto.message(sample_message)
120+
)
121+
122+
response = await grpc_client.send_message(sample_message_send_params)
123+
124+
mock_grpc_stub.SendMessage.assert_awaited_once()
125+
assert isinstance(response, Message)
126+
assert response.messageId == sample_message.messageId
127+
128+
129+
@pytest.mark.asyncio
130+
async def test_send_message_streaming(
131+
grpc_client: A2AGrpcClient,
132+
mock_grpc_stub: AsyncMock,
133+
sample_message_send_params: MessageSendParams,
134+
):
135+
"""Test the streaming message functionality."""
136+
mock_stream = AsyncMock()
137+
138+
status_update = TaskStatusUpdateEvent(
139+
taskId='task-stream',
140+
contextId='ctx-stream',
141+
status=TaskStatus(state=TaskState.working),
142+
final=False,
143+
)
144+
artifact_update = TaskArtifactUpdateEvent(
145+
taskId='task-stream',
146+
contextId='ctx-stream',
147+
artifact=MagicMock(spec=types.Artifact),
148+
)
149+
final_task = Task(
150+
id='task-stream',
151+
contextId='ctx-stream',
152+
status=TaskStatus(state=TaskState.completed),
153+
)
154+
155+
stream_responses = [
156+
a2a_pb2.StreamResponse(
157+
status_update=proto_utils.ToProto.task_status_update_event(
158+
status_update
159+
)
160+
),
161+
a2a_pb2.StreamResponse(
162+
artifact_update=proto_utils.ToProto.task_artifact_update_event(
163+
artifact_update
164+
)
165+
),
166+
a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(final_task)),
167+
grpc.aio.EOF,
168+
]
169+
170+
mock_stream.read.side_effect = stream_responses
171+
mock_grpc_stub.SendStreamingMessage.return_value = mock_stream
172+
173+
results = [
174+
result
175+
async for result in grpc_client.send_message_streaming(
176+
sample_message_send_params
177+
)
178+
]
179+
180+
mock_grpc_stub.SendStreamingMessage.assert_called_once()
181+
assert len(results) == 3
182+
assert isinstance(results[0], TaskStatusUpdateEvent)
183+
assert isinstance(results[1], TaskArtifactUpdateEvent)
184+
assert isinstance(results[2], Task)
185+
assert results[2].status.state == TaskState.completed
186+
187+
188+
@pytest.mark.asyncio
189+
async def test_get_task(
190+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task
191+
):
192+
"""Test retrieving a task."""
193+
mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(sample_task)
194+
params = TaskQueryParams(id=sample_task.id)
195+
196+
response = await grpc_client.get_task(params)
197+
198+
mock_grpc_stub.GetTask.assert_awaited_once_with(
199+
a2a_pb2.GetTaskRequest(name=f'tasks/{sample_task.id}')
200+
)
201+
assert response.id == sample_task.id
202+
203+
204+
@pytest.mark.asyncio
205+
async def test_cancel_task(
206+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task
207+
):
208+
"""Test cancelling a task."""
209+
cancelled_task = sample_task.model_copy()
210+
cancelled_task.status.state = TaskState.canceled
211+
mock_grpc_stub.CancelTask.return_value = proto_utils.ToProto.task(
212+
cancelled_task
213+
)
214+
params = TaskIdParams(id=sample_task.id)
215+
216+
response = await grpc_client.cancel_task(params)
217+
218+
mock_grpc_stub.CancelTask.assert_awaited_once_with(
219+
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}')
220+
)
221+
assert response.status.state == TaskState.canceled
222+
223+
224+
@pytest.mark.asyncio
225+
async def test_set_task_callback(
226+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock
227+
):
228+
"""Test setting a task callback."""
229+
task_id = 'task-callback-1'
230+
config = TaskPushNotificationConfig(
231+
taskId=task_id,
232+
pushNotificationConfig=PushNotificationConfig(
233+
url='http://my.callback/push', token='secret'
234+
),
235+
)
236+
proto_config = proto_utils.ToProto.task_push_notification_config(config)
237+
mock_grpc_stub.CreateTaskPushNotification.return_value = proto_config
238+
239+
response = await grpc_client.set_task_callback(config)
240+
241+
mock_grpc_stub.CreateTaskPushNotification.assert_awaited_once()
242+
call_args, _ = mock_grpc_stub.CreateTaskPushNotification.call_args
243+
sent_request = call_args[0]
244+
assert isinstance(sent_request, a2a_pb2.CreateTaskPushNotificationRequest)
245+
246+
assert response.taskId == task_id
247+
assert response.pushNotificationConfig.url == 'http://my.callback/push'
248+
249+
250+
@pytest.mark.asyncio
251+
async def test_get_task_callback(
252+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock
253+
):
254+
"""Test getting a task callback."""
255+
task_id = 'task-get-callback-1'
256+
push_id = 'undefined' # As per current implementation
257+
resource_name = f'tasks/{task_id}/pushNotification/{push_id}'
258+
259+
config = TaskPushNotificationConfig(
260+
taskId=task_id,
261+
pushNotificationConfig=PushNotificationConfig(
262+
url='http://my.callback/get', token='secret-get'
263+
),
264+
)
265+
proto_config = proto_utils.ToProto.task_push_notification_config(config)
266+
mock_grpc_stub.GetTaskPushNotification.return_value = proto_config
267+
268+
params = TaskIdParams(id=task_id)
269+
response = await grpc_client.get_task_callback(params)
270+
271+
mock_grpc_stub.GetTaskPushNotification.assert_awaited_once_with(
272+
a2a_pb2.GetTaskPushNotificationRequest(name=resource_name)
273+
)
274+
assert response.taskId == task_id
275+
assert response.pushNotificationConfig.url == 'http://my.callback/get'
276+
277+
278+
@pytest.mark.asyncio
279+
async def test_send_message_streaming_with_msg_and_task(
280+
grpc_client: A2AGrpcClient,
281+
mock_grpc_stub: AsyncMock,
282+
sample_message_send_params: MessageSendParams,
283+
):
284+
"""Test streaming response that contains both message and task types."""
285+
mock_stream = AsyncMock()
286+
287+
msg_event = Message(role=Role.agent, messageId='msg-stream-1', parts=[])
288+
task_event = Task(
289+
id='task-stream-1',
290+
contextId='ctx-stream-1',
291+
status=TaskStatus(state=TaskState.completed),
292+
)
293+
294+
stream_responses = [
295+
a2a_pb2.StreamResponse(msg=proto_utils.ToProto.message(msg_event)),
296+
a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(task_event)),
297+
grpc.aio.EOF,
298+
]
299+
300+
mock_stream.read.side_effect = stream_responses
301+
mock_grpc_stub.SendStreamingMessage.return_value = mock_stream
302+
303+
results = [
304+
result
305+
async for result in grpc_client.send_message_streaming(
306+
sample_message_send_params
307+
)
308+
]
309+
310+
assert len(results) == 2
311+
assert isinstance(results[0], Message)
312+
assert isinstance(results[1], Task)

tests/server/agent_execution/test_context.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
MessageSendParams,
1111
Task,
1212
)
13+
from a2a.utils.errors import ServerError
1314

1415

1516
class TestRequestContext:
@@ -165,6 +166,33 @@ def test_check_or_generate_context_id_with_existing_context_id(
165166
assert context.context_id == existing_id
166167
assert mock_params.message.contextId == existing_id
167168

169+
def test_init_raises_error_on_task_id_mismatch(
170+
self, mock_params, mock_task
171+
):
172+
"""Test that an error is raised if provided task_id mismatches task.id."""
173+
with pytest.raises(ServerError) as exc_info:
174+
RequestContext(
175+
request=mock_params, task_id='wrong-task-id', task=mock_task
176+
)
177+
assert 'bad task id' in str(exc_info.value.error.message)
178+
179+
def test_init_raises_error_on_context_id_mismatch(
180+
self, mock_params, mock_task
181+
):
182+
"""Test that an error is raised if provided context_id mismatches task.contextId."""
183+
# Set a valid task_id to avoid that error
184+
mock_params.message.taskId = mock_task.id
185+
186+
with pytest.raises(ServerError) as exc_info:
187+
RequestContext(
188+
request=mock_params,
189+
task_id=mock_task.id,
190+
context_id='wrong-context-id',
191+
task=mock_task,
192+
)
193+
194+
assert 'bad context id' in str(exc_info.value.error.message)
195+
168196
def test_with_related_tasks_provided(self, mock_task):
169197
"""Test initialization with related tasks provided."""
170198
related_tasks = [mock_task, Mock(spec=Task)]

tests/server/apps/jsonrpc/test_jsonrpc_app.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,20 @@ def test_jsonrpc_app_build_method_abstract_raises_typeerror(
7070
# Ensure 'supportsAuthenticatedExtendedCard' attribute exists
7171
mock_agent_card.supportsAuthenticatedExtendedCard = False
7272

73-
class AbstractTester(JSONRPCApplication):
74-
# No 'build' method implemented
75-
pass
76-
77-
# Instantiating an ABC subclass that doesn't implement all abstract methods raises TypeError
73+
# This will fail at definition time if an abstract method is not implemented
7874
with pytest.raises(
7975
TypeError,
80-
match="Can't instantiate abstract class AbstractTester with abstract method build",
76+
match="Can't instantiate abstract class IncompleteJSONRPCApp with abstract method build",
8177
):
82-
# Using positional arguments for the abstract class constructor
83-
AbstractTester(mock_handler, mock_agent_card)
78+
79+
class IncompleteJSONRPCApp(JSONRPCApplication):
80+
# Intentionally not implementing 'build'
81+
def some_other_method(self):
82+
pass
83+
84+
IncompleteJSONRPCApp(
85+
agent_card=mock_agent_card, http_handler=mock_handler
86+
)
8487

8588

8689
if __name__ == '__main__':

0 commit comments

Comments
 (0)