1- from unittest .mock import AsyncMock , MagicMock
1+ from unittest .mock import AsyncMock
22
33import grpc
44import pytest
55
6+ from a2a import types
67from a2a .client import A2AGrpcClient
78from a2a .grpc import a2a_pb2 , a2a_pb2_grpc
89from a2a .types import (
2930# Fixtures
3031@pytest .fixture
3132def mock_grpc_stub () -> AsyncMock :
32- """Provides a mock gRPC stub."""
33- return AsyncMock (spec = a2a_pb2_grpc .A2AServiceStub )
33+ """Provides a mock gRPC stub with methods mocked."""
34+ stub = AsyncMock (spec = a2a_pb2_grpc .A2AServiceStub )
35+ stub .SendMessage = AsyncMock ()
36+ stub .SendStreamingMessage = AsyncMock ()
37+ stub .GetTask = AsyncMock ()
38+ stub .CancelTask = AsyncMock ()
39+ stub .CreateTaskPushNotification = AsyncMock ()
40+ stub .GetTaskPushNotification = AsyncMock ()
41+ return stub
3442
3543
3644@pytest .fixture
@@ -144,7 +152,10 @@ async def test_send_message_streaming(
144152 artifact_update = TaskArtifactUpdateEvent (
145153 taskId = 'task-stream' ,
146154 contextId = 'ctx-stream' ,
147- artifact = MagicMock (spec = types .Artifact ),
155+ artifact = types .Artifact (
156+ artifactId = 'art-stream' ,
157+ parts = [types .Part (root = types .TextPart (text = 'data' ))],
158+ ),
148159 )
149160 final_task = Task (
150161 id = 'task-stream' ,
@@ -233,8 +244,14 @@ async def test_set_task_callback(
233244 url = 'http://my.callback/push' , token = 'secret'
234245 ),
235246 )
236- proto_config = proto_utils .ToProto .task_push_notification_config (config )
237- mock_grpc_stub .CreateTaskPushNotification .return_value = proto_config
247+ # The gRPC method returns the proto version of TaskPushNotificationConfig, not the inner config
248+ proto_response = a2a_pb2 .TaskPushNotificationConfig (
249+ name = f'tasks/{ task_id } /pushNotifications/{ config .pushNotificationConfig .id or "some_id" } ' ,
250+ push_notification_config = proto_utils .ToProto .push_notification_config (
251+ config .pushNotificationConfig
252+ ),
253+ )
254+ mock_grpc_stub .CreateTaskPushNotification .return_value = proto_response
238255
239256 response = await grpc_client .set_task_callback (config )
240257
@@ -256,14 +273,20 @@ async def test_get_task_callback(
256273 push_id = 'undefined' # As per current implementation
257274 resource_name = f'tasks/{ task_id } /pushNotification/{ push_id } '
258275
259- config = TaskPushNotificationConfig (
276+ config_model = TaskPushNotificationConfig (
260277 taskId = task_id ,
261278 pushNotificationConfig = PushNotificationConfig (
262- url = 'http://my.callback/get' , token = 'secret-get'
279+ id = push_id , url = 'http://my.callback/get' , token = 'secret-get'
280+ ),
281+ )
282+
283+ proto_response = a2a_pb2 .TaskPushNotificationConfig (
284+ name = resource_name ,
285+ push_notification_config = proto_utils .ToProto .push_notification_config (
286+ config_model .pushNotificationConfig
263287 ),
264288 )
265- proto_config = proto_utils .ToProto .task_push_notification_config (config )
266- mock_grpc_stub .GetTaskPushNotification .return_value = proto_config
289+ mock_grpc_stub .GetTaskPushNotification .return_value = proto_response
267290
268291 params = TaskIdParams (id = task_id )
269292 response = await grpc_client .get_task_callback (params )
0 commit comments