1- from unittest .mock import AsyncMock
1+ from unittest .mock import AsyncMock , MagicMock
22
3+ import grpc
34import pytest
45
56from a2a .client import A2AGrpcClient
67from a2a .grpc import a2a_pb2 , a2a_pb2_grpc
78from a2a .types import (
89 AgentCapabilities ,
910 AgentCard ,
11+ Artifact ,
1012 Message ,
1113 MessageSendParams ,
1214 Part ,
15+ PushNotificationAuthenticationInfo ,
16+ PushNotificationConfig ,
1317 Role ,
1418 Task ,
19+ TaskArtifactUpdateEvent ,
1520 TaskIdParams ,
21+ TaskPushNotificationConfig ,
1622 TaskQueryParams ,
1723 TaskState ,
1824 TaskStatus ,
25+ TaskStatusUpdateEvent ,
1926 TextPart ,
2027)
2128from a2a .utils import proto_utils
29+ from a2a .utils .errors import ServerError
2230
2331
2432# Fixtures
@@ -30,8 +38,8 @@ def mock_grpc_stub() -> AsyncMock:
3038 stub .SendStreamingMessage = AsyncMock ()
3139 stub .GetTask = AsyncMock ()
3240 stub .CancelTask = AsyncMock ()
33- stub .CreateTaskPushNotification = AsyncMock ()
34- stub .GetTaskPushNotification = AsyncMock ()
41+ stub .CreateTaskPushNotificationConfig = AsyncMock ()
42+ stub .GetTaskPushNotificationConfig = AsyncMock ()
3543 return stub
3644
3745
@@ -90,6 +98,78 @@ def sample_message() -> Message:
9098 )
9199
92100
101+ @pytest .fixture
102+ def sample_artifact () -> Artifact :
103+ """Provides a sample Artifact object."""
104+ return Artifact (
105+ artifactId = 'artifact-1' ,
106+ name = 'example.txt' ,
107+ description = 'An example artifact' ,
108+ parts = [Part (root = TextPart (text = 'Hi there' ))],
109+ metadata = {},
110+ extensions = [],
111+ )
112+
113+
114+ @pytest .fixture
115+ def sample_task_status_update_event () -> TaskStatusUpdateEvent :
116+ """Provides a sample TaskStatusUpdateEvent."""
117+ return TaskStatusUpdateEvent (
118+ taskId = 'task-1' ,
119+ contextId = 'ctx-1' ,
120+ status = TaskStatus (state = TaskState .working ),
121+ final = False ,
122+ metadata = {},
123+ )
124+
125+
126+ @pytest .fixture
127+ def sample_task_artifact_update_event (
128+ sample_artifact ,
129+ ) -> TaskArtifactUpdateEvent :
130+ """Provides a sample TaskArtifactUpdateEvent."""
131+ return TaskArtifactUpdateEvent (
132+ taskId = 'task-1' ,
133+ contextId = 'ctx-1' ,
134+ artifact = sample_artifact ,
135+ append = True ,
136+ last_chunk = True ,
137+ metadata = {},
138+ )
139+
140+
141+ @pytest .fixture
142+ def sample_authentication_info () -> PushNotificationAuthenticationInfo :
143+ """Provides a sample AuthenticationInfo object."""
144+ return PushNotificationAuthenticationInfo (
145+ schemes = ['apikey' , 'oauth2' ], credentials = 'secret-token'
146+ )
147+
148+
149+ @pytest .fixture
150+ def sample_push_notification_config (
151+ sample_authentication_info : PushNotificationAuthenticationInfo ,
152+ ) -> PushNotificationConfig :
153+ """Provides a sample PushNotificationConfig object."""
154+ return PushNotificationConfig (
155+ id = 'config-1' ,
156+ url = 'https://example.com/notify' ,
157+ token = 'example-token' ,
158+ authentication = sample_authentication_info ,
159+ )
160+
161+
162+ @pytest .fixture
163+ def sample_task_push_notification_config (
164+ sample_push_notification_config : PushNotificationConfig ,
165+ ) -> TaskPushNotificationConfig :
166+ """Provides a sample TaskPushNotificationConfig object."""
167+ return TaskPushNotificationConfig (
168+ taskId = 'task-1' ,
169+ pushNotificationConfig = sample_push_notification_config ,
170+ )
171+
172+
93173@pytest .mark .asyncio
94174async def test_send_message_task_response (
95175 grpc_client : A2AGrpcClient ,
@@ -109,6 +189,76 @@ async def test_send_message_task_response(
109189 assert response .id == sample_task .id
110190
111191
192+ @pytest .mark .asyncio
193+ async def test_send_message_message_response (
194+ grpc_client : A2AGrpcClient ,
195+ mock_grpc_stub : AsyncMock ,
196+ sample_message_send_params : MessageSendParams ,
197+ sample_message : Message ,
198+ ):
199+ """Test send_message that returns a Message."""
200+ mock_grpc_stub .SendMessage .return_value = a2a_pb2 .SendMessageResponse (
201+ msg = proto_utils .ToProto .message (sample_message )
202+ )
203+
204+ response = await grpc_client .send_message (sample_message_send_params )
205+
206+ mock_grpc_stub .SendMessage .assert_awaited_once ()
207+ assert isinstance (response , Message )
208+ assert response .messageId == sample_message .messageId
209+
210+
211+ @pytest .mark .asyncio
212+ async def test_send_message_streaming (
213+ grpc_client : A2AGrpcClient ,
214+ mock_grpc_stub : AsyncMock ,
215+ sample_message_send_params : MessageSendParams ,
216+ sample_message : Message ,
217+ sample_task : Task ,
218+ sample_task_status_update_event : TaskStatusUpdateEvent ,
219+ sample_task_artifact_update_event : TaskArtifactUpdateEvent ,
220+ ):
221+ """Test send_message_streaming that yields responses."""
222+ stream = MagicMock ()
223+ stream .read = AsyncMock (
224+ side_effect = [
225+ a2a_pb2 .StreamResponse (
226+ msg = proto_utils .ToProto .message (sample_message )
227+ ),
228+ a2a_pb2 .StreamResponse (task = proto_utils .ToProto .task (sample_task )),
229+ a2a_pb2 .StreamResponse (
230+ status_update = proto_utils .ToProto .task_status_update_event (
231+ sample_task_status_update_event
232+ )
233+ ),
234+ a2a_pb2 .StreamResponse (
235+ artifact_update = proto_utils .ToProto .task_artifact_update_event (
236+ sample_task_artifact_update_event
237+ )
238+ ),
239+ grpc .aio .EOF ,
240+ ]
241+ )
242+ mock_grpc_stub .SendStreamingMessage .return_value = stream
243+
244+ responses = [
245+ response
246+ async for response in grpc_client .send_message_streaming (
247+ sample_message_send_params
248+ )
249+ ]
250+
251+ mock_grpc_stub .SendStreamingMessage .assert_awaited_once ()
252+ assert isinstance (responses [0 ], Message )
253+ assert responses [0 ].messageId == sample_message .messageId
254+ assert isinstance (responses [1 ], Task )
255+ assert responses [1 ].id == sample_task .id
256+ assert isinstance (responses [2 ], TaskStatusUpdateEvent )
257+ assert responses [2 ].taskId == sample_task_status_update_event .taskId
258+ assert isinstance (responses [3 ], TaskArtifactUpdateEvent )
259+ assert responses [3 ].taskId == sample_task_artifact_update_event .taskId
260+
261+
112262@pytest .mark .asyncio
113263async def test_get_task (
114264 grpc_client : A2AGrpcClient , mock_grpc_stub : AsyncMock , sample_task : Task
@@ -143,3 +293,117 @@ async def test_cancel_task(
143293 a2a_pb2 .CancelTaskRequest (name = f'tasks/{ sample_task .id } ' )
144294 )
145295 assert response .status .state == TaskState .canceled
296+
297+
298+ @pytest .mark .asyncio
299+ async def test_set_task_callback_with_valid_task (
300+ grpc_client : A2AGrpcClient ,
301+ mock_grpc_stub : AsyncMock ,
302+ sample_task_push_notification_config : TaskPushNotificationConfig ,
303+ ):
304+ """Test setting a task push notification config with a valid task id."""
305+ task_id = 'task-1'
306+ config_id = 'config-1'
307+ mock_grpc_stub .CreateTaskPushNotificationConfig .return_value = (
308+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
309+ parent = f'tasks/{ task_id } ' ,
310+ config_id = config_id ,
311+ config = proto_utils .ToProto .task_push_notification_config (
312+ sample_task_push_notification_config
313+ ),
314+ )
315+ )
316+
317+ response = await grpc_client .set_task_callback (
318+ sample_task_push_notification_config
319+ )
320+
321+ mock_grpc_stub .CreateTaskPushNotificationConfig .assert_awaited_once_with (
322+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
323+ config = proto_utils .ToProto .task_push_notification_config (
324+ sample_task_push_notification_config
325+ ),
326+ )
327+ )
328+ assert response .taskId == task_id
329+
330+
331+ @pytest .mark .asyncio
332+ async def test_set_task_callback_with_invalid_task (
333+ grpc_client : A2AGrpcClient ,
334+ mock_grpc_stub : AsyncMock ,
335+ sample_task_push_notification_config : TaskPushNotificationConfig ,
336+ ):
337+ """Test setting a task push notification config with a invalid task id."""
338+ task_id = 'task-1'
339+ config_id = 'config-1'
340+ mock_grpc_stub .CreateTaskPushNotificationConfig .return_value = (
341+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
342+ parent = f'invalid-path-to-tasks/{ task_id } ' ,
343+ config_id = config_id ,
344+ config = proto_utils .ToProto .task_push_notification_config (
345+ sample_task_push_notification_config
346+ ),
347+ )
348+ )
349+
350+ with pytest .raises (ServerError ) as exc_info :
351+ await grpc_client .set_task_callback (
352+ sample_task_push_notification_config
353+ )
354+ assert 'No task for' in exc_info .value .error .message
355+
356+
357+ @pytest .mark .asyncio
358+ async def test_get_task_callback_with_valid_task (
359+ grpc_client : A2AGrpcClient ,
360+ mock_grpc_stub : AsyncMock ,
361+ sample_task_push_notification_config : TaskPushNotificationConfig ,
362+ ):
363+ """Test retrieving a task push notification config with a valid task id."""
364+ task_id = 'task-1'
365+ config_id = 'config-1'
366+ mock_grpc_stub .GetTaskPushNotificationConfig .return_value = (
367+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
368+ parent = f'tasks/{ task_id } ' ,
369+ config_id = config_id ,
370+ config = proto_utils .ToProto .task_push_notification_config (
371+ sample_task_push_notification_config
372+ ),
373+ )
374+ )
375+ params = TaskIdParams (id = sample_task_push_notification_config .taskId )
376+
377+ response = await grpc_client .get_task_callback (params )
378+
379+ mock_grpc_stub .GetTaskPushNotificationConfig .assert_awaited_once_with (
380+ a2a_pb2 .GetTaskPushNotificationConfigRequest (
381+ name = f'tasks/{ params .id } /pushNotification/undefined' ,
382+ )
383+ )
384+ assert response .taskId == task_id
385+
386+
387+ @pytest .mark .asyncio
388+ async def test_get_task_callback_with_invalid_task (
389+ grpc_client : A2AGrpcClient ,
390+ mock_grpc_stub : AsyncMock ,
391+ sample_task_push_notification_config : TaskPushNotificationConfig ,
392+ ):
393+ """Test retrieving a task push notification config with a invalid task id."""
394+ task_id = 'task-1'
395+ config_id = 'config-1'
396+ mock_grpc_stub .GetTaskPushNotificationConfig .return_value = (
397+ a2a_pb2 .CreateTaskPushNotificationConfigRequest (
398+ parent = f'invalid-path-to-tasks/{ task_id } ' ,
399+ config_id = config_id ,
400+ config = proto_utils .ToProto .task_push_notification_config (
401+ sample_task_push_notification_config
402+ ),
403+ )
404+ )
405+ params = TaskIdParams (id = sample_task_push_notification_config .taskId )
406+
407+ with pytest .raises (ServerError ) as exc_info :
408+ await grpc_client .get_task_callback (params )
409+ assert 'No task for' in exc_info .value .error .message
0 commit comments