Skip to content

Commit 64648c9

Browse files
committed
Update tests
1 parent 1caef80 commit 64648c9

File tree

2 files changed

+42
-14
lines changed

2 files changed

+42
-14
lines changed

tests/client/test_grpc_client.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from unittest.mock import AsyncMock, MagicMock
1+
from unittest.mock import AsyncMock
22

33
import grpc
44
import pytest
55

6+
from a2a import types
67
from a2a.client import A2AGrpcClient
78
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
89
from a2a.types import (
@@ -29,8 +30,15 @@
2930
# Fixtures
3031
@pytest.fixture
3132
def mock_grpc_stub() -> AsyncMock:
32-
"""Provides a mock gRPC stub."""
33-
return AsyncMock(spec=a2a_pb2_grpc.A2AServiceStub)
33+
"""Provides a mock gRPC stub with methods mocked."""
34+
stub = AsyncMock(spec=a2a_pb2_grpc.A2AServiceStub)
35+
stub.SendMessage = AsyncMock()
36+
stub.SendStreamingMessage = AsyncMock()
37+
stub.GetTask = AsyncMock()
38+
stub.CancelTask = AsyncMock()
39+
stub.CreateTaskPushNotification = AsyncMock()
40+
stub.GetTaskPushNotification = AsyncMock()
41+
return stub
3442

3543

3644
@pytest.fixture
@@ -144,7 +152,10 @@ async def test_send_message_streaming(
144152
artifact_update = TaskArtifactUpdateEvent(
145153
taskId='task-stream',
146154
contextId='ctx-stream',
147-
artifact=MagicMock(spec=types.Artifact),
155+
artifact=types.Artifact(
156+
artifactId='art-stream',
157+
parts=[types.Part(root=types.TextPart(text='data'))],
158+
),
148159
)
149160
final_task = Task(
150161
id='task-stream',
@@ -233,8 +244,14 @@ async def test_set_task_callback(
233244
url='http://my.callback/push', token='secret'
234245
),
235246
)
236-
proto_config = proto_utils.ToProto.task_push_notification_config(config)
237-
mock_grpc_stub.CreateTaskPushNotification.return_value = proto_config
247+
# The gRPC method returns the proto version of TaskPushNotificationConfig, not the inner config
248+
proto_response = a2a_pb2.TaskPushNotificationConfig(
249+
name=f'tasks/{task_id}/pushNotifications/{config.pushNotificationConfig.id or "some_id"}',
250+
push_notification_config=proto_utils.ToProto.push_notification_config(
251+
config.pushNotificationConfig
252+
),
253+
)
254+
mock_grpc_stub.CreateTaskPushNotification.return_value = proto_response
238255

239256
response = await grpc_client.set_task_callback(config)
240257

@@ -256,14 +273,20 @@ async def test_get_task_callback(
256273
push_id = 'undefined' # As per current implementation
257274
resource_name = f'tasks/{task_id}/pushNotification/{push_id}'
258275

259-
config = TaskPushNotificationConfig(
276+
config_model = TaskPushNotificationConfig(
260277
taskId=task_id,
261278
pushNotificationConfig=PushNotificationConfig(
262-
url='http://my.callback/get', token='secret-get'
279+
id=push_id, url='http://my.callback/get', token='secret-get'
280+
),
281+
)
282+
283+
proto_response = a2a_pb2.TaskPushNotificationConfig(
284+
name=resource_name,
285+
push_notification_config=proto_utils.ToProto.push_notification_config(
286+
config_model.pushNotificationConfig
263287
),
264288
)
265-
proto_config = proto_utils.ToProto.task_push_notification_config(config)
266-
mock_grpc_stub.GetTaskPushNotification.return_value = proto_config
289+
mock_grpc_stub.GetTaskPushNotification.return_value = proto_response
267290

268291
params = TaskIdParams(id=task_id)
269292
response = await grpc_client.get_task_callback(params)

tests/utils/test_proto_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest import mock
2+
13
import pytest
24

35
from a2a import types
@@ -111,12 +113,15 @@ class TestToProto:
111113
def test_part_unsupported_type(self):
112114
"""Test that ToProto.part raises ValueError for an unsupported Part type."""
113115

114-
class FakePart(types.PartBase):
115-
kind: str = 'fake'
116+
class FakePartType:
117+
kind = 'fake'
118+
119+
# Create a mock Part object that has a .root attribute pointing to the fake type
120+
mock_part = mock.MagicMock(spec=types.Part)
121+
mock_part.root = FakePartType()
116122

117-
unsupported_part = types.Part(root=FakePart())
118123
with pytest.raises(ValueError, match='Unsupported part type'):
119-
proto_utils.ToProto.part(unsupported_part)
124+
proto_utils.ToProto.part(mock_part)
120125

121126

122127
class TestFromProto:

0 commit comments

Comments
 (0)