|
1 | 1 | from unittest.mock import AsyncMock |
2 | 2 |
|
3 | | -import grpc |
4 | 3 | import pytest |
5 | 4 |
|
6 | | -from a2a import types |
7 | 5 | from a2a.client import A2AGrpcClient |
8 | 6 | from a2a.grpc import a2a_pb2, a2a_pb2_grpc |
9 | 7 | from a2a.types import ( |
|
12 | 10 | Message, |
13 | 11 | MessageSendParams, |
14 | 12 | Part, |
15 | | - PushNotificationConfig, |
16 | 13 | Role, |
17 | 14 | Task, |
18 | | - TaskArtifactUpdateEvent, |
19 | 15 | TaskIdParams, |
20 | | - TaskPushNotificationConfig, |
21 | 16 | TaskQueryParams, |
22 | 17 | TaskState, |
23 | 18 | TaskStatus, |
24 | | - TaskStatusUpdateEvent, |
25 | 19 | TextPart, |
26 | 20 | ) |
27 | 21 | from a2a.utils import proto_utils |
@@ -115,87 +109,6 @@ async def test_send_message_task_response( |
115 | 109 | assert response.id == sample_task.id |
116 | 110 |
|
117 | 111 |
|
118 | | -@pytest.mark.asyncio |
119 | | -async def test_send_message_message_response( |
120 | | - grpc_client: A2AGrpcClient, |
121 | | - mock_grpc_stub: AsyncMock, |
122 | | - sample_message_send_params: MessageSendParams, |
123 | | - sample_message: Message, |
124 | | -): |
125 | | - """Test send_message that returns a Message.""" |
126 | | - mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( |
127 | | - msg=proto_utils.ToProto.message(sample_message) |
128 | | - ) |
129 | | - |
130 | | - response = await grpc_client.send_message(sample_message_send_params) |
131 | | - |
132 | | - mock_grpc_stub.SendMessage.assert_awaited_once() |
133 | | - assert isinstance(response, Message) |
134 | | - assert response.messageId == sample_message.messageId |
135 | | - |
136 | | - |
137 | | -@pytest.mark.asyncio |
138 | | -async def test_send_message_streaming( |
139 | | - grpc_client: A2AGrpcClient, |
140 | | - mock_grpc_stub: AsyncMock, |
141 | | - sample_message_send_params: MessageSendParams, |
142 | | -): |
143 | | - """Test the streaming message functionality.""" |
144 | | - mock_stream = AsyncMock() |
145 | | - |
146 | | - status_update = TaskStatusUpdateEvent( |
147 | | - taskId='task-stream', |
148 | | - contextId='ctx-stream', |
149 | | - status=TaskStatus(state=TaskState.working), |
150 | | - final=False, |
151 | | - ) |
152 | | - artifact_update = TaskArtifactUpdateEvent( |
153 | | - taskId='task-stream', |
154 | | - contextId='ctx-stream', |
155 | | - artifact=types.Artifact( |
156 | | - artifactId='art-stream', |
157 | | - parts=[types.Part(root=types.TextPart(text='data'))], |
158 | | - ), |
159 | | - ) |
160 | | - final_task = Task( |
161 | | - id='task-stream', |
162 | | - contextId='ctx-stream', |
163 | | - status=TaskStatus(state=TaskState.completed), |
164 | | - ) |
165 | | - |
166 | | - stream_responses = [ |
167 | | - a2a_pb2.StreamResponse( |
168 | | - status_update=proto_utils.ToProto.task_status_update_event( |
169 | | - status_update |
170 | | - ) |
171 | | - ), |
172 | | - a2a_pb2.StreamResponse( |
173 | | - artifact_update=proto_utils.ToProto.task_artifact_update_event( |
174 | | - artifact_update |
175 | | - ) |
176 | | - ), |
177 | | - a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(final_task)), |
178 | | - grpc.aio.EOF, |
179 | | - ] |
180 | | - |
181 | | - mock_stream.read.side_effect = stream_responses |
182 | | - mock_grpc_stub.SendStreamingMessage.return_value = mock_stream |
183 | | - |
184 | | - results = [ |
185 | | - result |
186 | | - async for result in grpc_client.send_message_streaming( |
187 | | - sample_message_send_params |
188 | | - ) |
189 | | - ] |
190 | | - |
191 | | - mock_grpc_stub.SendStreamingMessage.assert_called_once() |
192 | | - assert len(results) == 3 |
193 | | - assert isinstance(results[0], TaskStatusUpdateEvent) |
194 | | - assert isinstance(results[1], TaskArtifactUpdateEvent) |
195 | | - assert isinstance(results[2], Task) |
196 | | - assert results[2].status.state == TaskState.completed |
197 | | - |
198 | | - |
199 | 112 | @pytest.mark.asyncio |
200 | 113 | async def test_get_task( |
201 | 114 | grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task |
@@ -230,106 +143,3 @@ async def test_cancel_task( |
230 | 143 | a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') |
231 | 144 | ) |
232 | 145 | assert response.status.state == TaskState.canceled |
233 | | - |
234 | | - |
235 | | -@pytest.mark.asyncio |
236 | | -async def test_set_task_callback( |
237 | | - grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock |
238 | | -): |
239 | | - """Test setting a task callback.""" |
240 | | - task_id = 'task-callback-1' |
241 | | - config = TaskPushNotificationConfig( |
242 | | - taskId=task_id, |
243 | | - pushNotificationConfig=PushNotificationConfig( |
244 | | - url='http://my.callback/push', token='secret' |
245 | | - ), |
246 | | - ) |
247 | | - # The gRPC method returns the proto version of TaskPushNotificationConfig, not the inner config |
248 | | - proto_response = a2a_pb2.TaskPushNotificationConfig( |
249 | | - name=f'tasks/{task_id}/pushNotifications/{config.pushNotificationConfig.id or "some_id"}', |
250 | | - push_notification_config=proto_utils.ToProto.push_notification_config( |
251 | | - config.pushNotificationConfig |
252 | | - ), |
253 | | - ) |
254 | | - mock_grpc_stub.CreateTaskPushNotification.return_value = proto_response |
255 | | - |
256 | | - response = await grpc_client.set_task_callback(config) |
257 | | - |
258 | | - mock_grpc_stub.CreateTaskPushNotification.assert_awaited_once() |
259 | | - call_args, _ = mock_grpc_stub.CreateTaskPushNotification.call_args |
260 | | - sent_request = call_args[0] |
261 | | - assert isinstance(sent_request, a2a_pb2.CreateTaskPushNotificationRequest) |
262 | | - |
263 | | - assert response.taskId == task_id |
264 | | - assert response.pushNotificationConfig.url == 'http://my.callback/push' |
265 | | - |
266 | | - |
267 | | -@pytest.mark.asyncio |
268 | | -async def test_get_task_callback( |
269 | | - grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock |
270 | | -): |
271 | | - """Test getting a task callback.""" |
272 | | - task_id = 'task-get-callback-1' |
273 | | - push_id = 'undefined' # As per current implementation |
274 | | - resource_name = f'tasks/{task_id}/pushNotification/{push_id}' |
275 | | - |
276 | | - config_model = TaskPushNotificationConfig( |
277 | | - taskId=task_id, |
278 | | - pushNotificationConfig=PushNotificationConfig( |
279 | | - id=push_id, url='http://my.callback/get', token='secret-get' |
280 | | - ), |
281 | | - ) |
282 | | - |
283 | | - proto_response = a2a_pb2.TaskPushNotificationConfig( |
284 | | - name=resource_name, |
285 | | - push_notification_config=proto_utils.ToProto.push_notification_config( |
286 | | - config_model.pushNotificationConfig |
287 | | - ), |
288 | | - ) |
289 | | - mock_grpc_stub.GetTaskPushNotification.return_value = proto_response |
290 | | - |
291 | | - params = TaskIdParams(id=task_id) |
292 | | - response = await grpc_client.get_task_callback(params) |
293 | | - |
294 | | - mock_grpc_stub.GetTaskPushNotification.assert_awaited_once_with( |
295 | | - a2a_pb2.GetTaskPushNotificationRequest(name=resource_name) |
296 | | - ) |
297 | | - assert response.taskId == task_id |
298 | | - assert response.pushNotificationConfig.url == 'http://my.callback/get' |
299 | | - |
300 | | - |
301 | | -@pytest.mark.asyncio |
302 | | -async def test_send_message_streaming_with_msg_and_task( |
303 | | - grpc_client: A2AGrpcClient, |
304 | | - mock_grpc_stub: AsyncMock, |
305 | | - sample_message_send_params: MessageSendParams, |
306 | | -): |
307 | | - """Test streaming response that contains both message and task types.""" |
308 | | - mock_stream = AsyncMock() |
309 | | - |
310 | | - msg_event = Message(role=Role.agent, messageId='msg-stream-1', parts=[]) |
311 | | - task_event = Task( |
312 | | - id='task-stream-1', |
313 | | - contextId='ctx-stream-1', |
314 | | - status=TaskStatus(state=TaskState.completed), |
315 | | - ) |
316 | | - |
317 | | - stream_responses = [ |
318 | | - a2a_pb2.StreamResponse(msg=proto_utils.ToProto.message(msg_event)), |
319 | | - a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(task_event)), |
320 | | - grpc.aio.EOF, |
321 | | - ] |
322 | | - |
323 | | - mock_stream.read.side_effect = stream_responses |
324 | | - mock_grpc_stub.SendStreamingMessage.return_value = mock_stream |
325 | | - |
326 | | - results = [ |
327 | | - result |
328 | | - async for result in grpc_client.send_message_streaming( |
329 | | - sample_message_send_params |
330 | | - ) |
331 | | - ] |
332 | | - |
333 | | - assert len(results) == 2 |
334 | | - assert isinstance(results[0], Message) |
335 | | - assert isinstance(results[1], Task) |
0 commit comments