Skip to content

Commit 4339f8d

Browse files
committed
Add more tests
1 parent d227dde commit 4339f8d

File tree

3 files changed

+713
-0
lines changed

3 files changed

+713
-0
lines changed

tests/client/test_grpc_client.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
from unittest.mock import AsyncMock
2+
3+
import grpc
4+
import pytest
5+
6+
from a2a import types
7+
from a2a.client.grpc_client import A2AGrpcClient
8+
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
9+
10+
11+
# --- Fixtures ---
12+
13+
14+
@pytest.fixture
15+
def mock_grpc_stub() -> AsyncMock:
16+
return AsyncMock(spec=a2a_pb2_grpc.A2AServiceStub)
17+
18+
19+
@pytest.fixture
20+
def sample_agent_card() -> types.AgentCard:
21+
return types.AgentCard(
22+
name='Test Agent',
23+
description='A test agent',
24+
url='http://localhost',
25+
version='1.0.0',
26+
capabilities=types.AgentCapabilities(
27+
streaming=True, pushNotifications=True
28+
),
29+
defaultInputModes=['text/plain'],
30+
defaultOutputModes=['text/plain'],
31+
skills=[],
32+
)
33+
34+
35+
@pytest.fixture
36+
def grpc_client(
37+
mock_grpc_stub: AsyncMock, sample_agent_card: types.AgentCard
38+
) -> A2AGrpcClient:
39+
return A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=sample_agent_card)
40+
41+
42+
# --- Test Cases ---
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_send_message_returns_task(
47+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock
48+
):
49+
"""Test send_message when the server returns a Task."""
50+
request_params = types.MessageSendParams(
51+
message=types.Message(role=types.Role.user, messageId='1', parts=[])
52+
)
53+
response_proto = a2a_pb2.SendMessageResponse(task=a2a_pb2.Task(id='task-1'))
54+
mock_grpc_stub.SendMessage.return_value = response_proto
55+
56+
result = await grpc_client.send_message(request_params)
57+
58+
mock_grpc_stub.SendMessage.assert_awaited_once()
59+
assert isinstance(result, types.Task)
60+
assert result.id == 'task-1'
61+
62+
63+
@pytest.mark.asyncio
64+
async def test_send_message_returns_message(
65+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock
66+
):
67+
"""Test send_message when the server returns a Message."""
68+
request_params = types.MessageSendParams(
69+
message=types.Message(role=types.Role.user, messageId='1', parts=[])
70+
)
71+
response_proto = a2a_pb2.SendMessageResponse(
72+
msg=a2a_pb2.Message(message_id='msg-resp-1')
73+
)
74+
mock_grpc_stub.SendMessage.return_value = response_proto
75+
76+
result = await grpc_client.send_message(request_params)
77+
78+
mock_grpc_stub.SendMessage.assert_awaited_once()
79+
assert isinstance(result, types.Message)
80+
assert result.messageId == 'msg-resp-1'
81+
82+
83+
@pytest.mark.asyncio
84+
async def test_send_message_streaming(
85+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock
86+
):
87+
"""Test the streaming message functionality."""
88+
request_params = types.MessageSendParams(
89+
message=types.Message(role=types.Role.user, messageId='1', parts=[])
90+
)
91+
92+
# Mock the stream object and its read method
93+
mock_stream = AsyncMock()
94+
stream_responses = [
95+
a2a_pb2.StreamResponse(task=a2a_pb2.Task(id='task-stream')),
96+
a2a_pb2.StreamResponse(msg=a2a_pb2.Message(message_id='msg-stream')),
97+
a2a_pb2.StreamResponse(
98+
status_update=a2a_pb2.TaskStatusUpdateEvent(task_id='task-stream')
99+
),
100+
a2a_pb2.StreamResponse(
101+
artifact_update=a2a_pb2.TaskArtifactUpdateEvent(
102+
task_id='task-stream'
103+
)
104+
),
105+
grpc.aio.EOF,
106+
]
107+
mock_stream.read.side_effect = stream_responses
108+
mock_grpc_stub.SendStreamingMessage.return_value = mock_stream
109+
110+
results = []
111+
async for item in grpc_client.send_message_streaming(request_params):
112+
results.append(item)
113+
114+
mock_grpc_stub.SendStreamingMessage.assert_called_once()
115+
assert len(results) == 4
116+
assert isinstance(results[0], types.Task)
117+
assert isinstance(results[1], types.Message)
118+
assert isinstance(results[2], types.TaskStatusUpdateEvent)
119+
assert isinstance(results[3], types.TaskArtifactUpdateEvent)
120+
121+
122+
@pytest.mark.asyncio
123+
async def test_get_task(grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock):
124+
"""Test retrieving a task."""
125+
request_params = types.TaskQueryParams(id='task-1')
126+
response_proto = a2a_pb2.Task(id='task-1', context_id='ctx-1')
127+
mock_grpc_stub.GetTask.return_value = response_proto
128+
129+
result = await grpc_client.get_task(request_params)
130+
131+
mock_grpc_stub.GetTask.assert_awaited_once_with(
132+
a2a_pb2.GetTaskRequest(name='tasks/task-1')
133+
)
134+
assert isinstance(result, types.Task)
135+
assert result.id == 'task-1'
136+
137+
138+
@pytest.mark.asyncio
139+
async def test_cancel_task(
140+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock
141+
):
142+
"""Test cancelling a task."""
143+
request_params = types.TaskIdParams(id='task-1')
144+
response_proto = a2a_pb2.Task(
145+
id='task-1',
146+
status=a2a_pb2.TaskStatus(state=a2a_pb2.TaskState.TASK_STATE_CANCELLED),
147+
)
148+
mock_grpc_stub.CancelTask.return_value = response_proto
149+
150+
result = await grpc_client.cancel_task(request_params)
151+
152+
mock_grpc_stub.CancelTask.assert_awaited_once_with(
153+
a2a_pb2.CancelTaskRequest(name='tasks/task-1')
154+
)
155+
assert isinstance(result, types.Task)
156+
assert result.status.state == types.TaskState.canceled
157+
158+
159+
@pytest.mark.asyncio
160+
async def test_set_task_callback(
161+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock
162+
):
163+
"""Test setting a task callback."""
164+
request_params = types.TaskPushNotificationConfig(
165+
taskId='task-1',
166+
pushNotificationConfig=types.PushNotificationConfig(
167+
url='http://callback.url'
168+
),
169+
)
170+
response_proto = a2a_pb2.TaskPushNotificationConfig(
171+
name='tasks/task-1/pushNotifications/config-1',
172+
push_notification_config=a2a_pb2.PushNotificationConfig(
173+
url='http://callback.url'
174+
),
175+
)
176+
mock_grpc_stub.CreateTaskPushNotification.return_value = response_proto
177+
178+
result = await grpc_client.set_task_callback(request_params)
179+
180+
mock_grpc_stub.CreateTaskPushNotification.assert_awaited_once()
181+
assert isinstance(result, types.TaskPushNotificationConfig)
182+
assert result.pushNotificationConfig.url == 'http://callback.url'
183+
184+
185+
@pytest.mark.asyncio
186+
async def test_get_task_callback(
187+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock
188+
):
189+
"""Test getting a task callback."""
190+
request_params = types.TaskIdParams(id='task-1')
191+
response_proto = a2a_pb2.TaskPushNotificationConfig(
192+
name='tasks/task-1/pushNotifications/undefined',
193+
push_notification_config=a2a_pb2.PushNotificationConfig(
194+
url='http://callback.url'
195+
),
196+
)
197+
mock_grpc_stub.GetTaskPushNotification.return_value = response_proto
198+
199+
result = await grpc_client.get_task_callback(request_params)
200+
201+
mock_grpc_stub.GetTaskPushNotification.assert_awaited_once_with(
202+
a2a_pb2.GetTaskPushNotificationRequest(
203+
name='tasks/task-1/pushNotifications/undefined'
204+
)
205+
)
206+
assert isinstance(result, types.TaskPushNotificationConfig)
207+
assert result.pushNotificationConfig.url == 'http://callback.url'

0 commit comments

Comments
 (0)