Skip to content

Commit 0b6eedb

Browse files
committed
Increase code coverage for proto_utils.py
1 parent cb3c237 commit 0b6eedb

File tree

1 file changed

+233
-2
lines changed

1 file changed

+233
-2
lines changed

tests/utils/test_proto_utils.py

Lines changed: 233 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest import mock
22

33
import pytest
4+
from google.protobuf import struct_pb2
45

56
from a2a import types
67
from a2a.grpc import a2a_pb2
@@ -31,7 +32,30 @@ def sample_message() -> types.Message:
3132
),
3233
types.Part(root=types.DataPart(data={'key': 'value'})),
3334
],
34-
metadata={'source': 'test'},
35+
metadata={'source': 'test', 'nested': {'key': 'value'}},
36+
)
37+
38+
39+
@pytest.fixture
40+
def sample_message_with_bytes() -> types.Message:
41+
"""A sample message that includes a file with raw bytes."""
42+
return types.Message(
43+
message_id='msg-2',
44+
context_id='ctx-2',
45+
task_id='task-2',
46+
role=types.Role.user,
47+
parts=[
48+
types.Part(
49+
root=types.FilePart(
50+
file=types.FileWithBytes(
51+
bytes='file content',
52+
name='bytes.txt',
53+
mime_type='text/plain',
54+
),
55+
)
56+
)
57+
],
58+
metadata={},
3559
)
3660

3761

@@ -47,6 +71,9 @@ def sample_task(sample_message: types.Message) -> types.Task:
4771
artifacts=[
4872
types.Artifact(
4973
artifact_id='art-1',
74+
name='artifact1',
75+
description='An artifact',
76+
metadata={'source': 'test'},
5077
parts=[
5178
types.Part(root=types.TextPart(text='Artifact content'))
5279
],
@@ -63,7 +90,16 @@ def sample_agent_card() -> types.AgentCard:
6390
url='http://localhost',
6491
version='1.0.0',
6592
capabilities=types.AgentCapabilities(
66-
streaming=True, push_notifications=True
93+
streaming=True,
94+
push_notifications=True,
95+
extensions=[
96+
types.AgentExtension(
97+
uri='ext:uri',
98+
description='An extension',
99+
params={'key': 'value'},
100+
required=True,
101+
)
102+
],
67103
),
68104
default_input_modes=['text/plain'],
69105
default_output_modes=['text/plain'],
@@ -73,6 +109,9 @@ def sample_agent_card() -> types.AgentCard:
73109
name='Test Skill',
74110
description='A test skill',
75111
tags=['test'],
112+
examples=['example1'],
113+
input_modes=['text/plain'],
114+
output_modes=['text/markdown'],
76115
)
77116
],
78117
provider=types.AgentProvider(
@@ -106,10 +145,48 @@ def sample_agent_card() -> types.AgentCard:
106145
open_id_connect_url='http://oidc.url'
107146
)
108147
),
148+
'mtls': types.SecurityScheme(
149+
root=types.MutualTLSSecurityScheme(description='mTLS auth')
150+
),
109151
},
110152
)
111153

112154

155+
@pytest.fixture
156+
def sample_task_status_update() -> types.TaskStatusUpdateEvent:
157+
"""Sample TaskStatusUpdateEvent for testing."""
158+
return types.TaskStatusUpdateEvent(
159+
task_id='task-1',
160+
context_id='ctx-1',
161+
status=types.TaskStatus(
162+
state=types.TaskState.completed,
163+
message=types.Message(
164+
message_id='msg-final',
165+
role=types.Role.agent,
166+
parts=[types.Part(root=types.TextPart(text='Done!'))],
167+
),
168+
),
169+
metadata={'final_update': True},
170+
final=True,
171+
)
172+
173+
174+
@pytest.fixture
175+
def sample_task_artifact_update() -> types.TaskArtifactUpdateEvent:
176+
"""Sample TaskArtifactUpdateEvent for testing."""
177+
return types.TaskArtifactUpdateEvent(
178+
task_id='task-1',
179+
context_id='ctx-1',
180+
artifact=types.Artifact(
181+
artifact_id='art-1',
182+
parts=[types.Part(root=types.TextPart(text=' some data'))],
183+
),
184+
metadata={'chunk': 2},
185+
append=True,
186+
last_chunk=False,
187+
)
188+
189+
113190
# --- Test Cases ---
114191

115192

@@ -127,6 +204,11 @@ class FakePartType:
127204
with pytest.raises(ValueError, match='Unsupported part type'):
128205
proto_utils.ToProto.part(mock_part)
129206

207+
def test_update_event_unsupported_type(self):
208+
"""Test that ToProto.update_event raises ValueError for an unsupported type."""
209+
with pytest.raises(ValueError, match='Unsupported event type'):
210+
proto_utils.ToProto.update_event('not a valid event')
211+
130212

131213
class TestFromProto:
132214
def test_part_unsupported_type(self):
@@ -143,6 +225,14 @@ def test_task_query_params_invalid_name(self):
143225
proto_utils.FromProto.task_query_params(request)
144226
assert isinstance(exc_info.value.error, types.InvalidParamsError)
145227

228+
def test_task_push_notification_config_request_invalid_parent(self):
229+
request = a2a_pb2.CreateTaskPushNotificationConfigRequest(
230+
parent='invalid-parent-format'
231+
)
232+
with pytest.raises(ServerError) as exc_info:
233+
proto_utils.FromProto.task_push_notification_config_request(request)
234+
assert isinstance(exc_info.value.error, types.InvalidParamsError)
235+
146236

147237
class TestProtoUtils:
148238
def test_roundtrip_message(self, sample_message: types.Message):
@@ -158,6 +248,92 @@ def test_roundtrip_message(self, sample_message: types.Message):
158248
roundtrip_msg = proto_utils.FromProto.message(proto_msg)
159249
assert roundtrip_msg == sample_message
160250

251+
def test_roundtrip_message_with_bytes(
252+
self, sample_message_with_bytes: types.Message
253+
):
254+
"""Test conversion of a Message with file bytes to proto and back."""
255+
proto_msg = proto_utils.ToProto.message(sample_message_with_bytes)
256+
assert isinstance(proto_msg, a2a_pb2.Message)
257+
assert proto_msg.content[0].file.file_with_bytes == b'file content'
258+
259+
roundtrip_msg = proto_utils.FromProto.message(proto_msg)
260+
assert roundtrip_msg == sample_message_with_bytes
261+
# Also check metadata with empty dict
262+
assert roundtrip_msg.metadata == {}
263+
264+
def test_roundtrip_task(self, sample_task: types.Task):
265+
"""Test conversion of Task to proto and back."""
266+
proto_task = proto_utils.ToProto.task(sample_task)
267+
assert isinstance(proto_task, a2a_pb2.Task)
268+
269+
roundtrip_task = proto_utils.FromProto.task(proto_task)
270+
assert roundtrip_task == sample_task
271+
272+
def test_roundtrip_agent_card(self, sample_agent_card: types.AgentCard):
273+
"""Test conversion of AgentCard to proto and back."""
274+
proto_card = proto_utils.ToProto.agent_card(sample_agent_card)
275+
assert isinstance(proto_card, a2a_pb2.AgentCard)
276+
assert proto_card.security_schemes['mtls'].HasField(
277+
'mtls_security_scheme'
278+
)
279+
280+
roundtrip_card = proto_utils.FromProto.agent_card(proto_card)
281+
assert roundtrip_card == sample_agent_card
282+
283+
def test_roundtrip_stream_responses(
284+
self,
285+
sample_task_status_update: types.TaskStatusUpdateEvent,
286+
sample_task_artifact_update: types.TaskArtifactUpdateEvent,
287+
sample_task: types.Task,
288+
sample_message: types.Message,
289+
):
290+
"""Test roundtrip conversion for all streamable event types."""
291+
# TaskStatusUpdateEvent
292+
proto_status_update = proto_utils.ToProto.stream_response(
293+
sample_task_status_update
294+
)
295+
roundtrip_status_update = proto_utils.FromProto.stream_response(
296+
proto_status_update
297+
)
298+
assert roundtrip_status_update == sample_task_status_update
299+
300+
# TaskArtifactUpdateEvent
301+
proto_artifact_update = proto_utils.ToProto.stream_response(
302+
sample_task_artifact_update
303+
)
304+
roundtrip_artifact_update = proto_utils.FromProto.stream_response(
305+
proto_artifact_update
306+
)
307+
assert roundtrip_artifact_update == sample_task_artifact_update
308+
309+
# Task
310+
proto_task = proto_utils.ToProto.stream_response(sample_task)
311+
roundtrip_task = proto_utils.FromProto.stream_response(proto_task)
312+
assert roundtrip_task == sample_task
313+
314+
# Message
315+
proto_message = proto_utils.ToProto.stream_response(sample_message)
316+
roundtrip_message = proto_utils.FromProto.stream_response(proto_message)
317+
assert roundtrip_message == sample_message
318+
319+
def test_task_or_message_conversion(
320+
self, sample_task: types.Task, sample_message: types.Message
321+
):
322+
"""Test ToProto and FromProto for task_or_message methods."""
323+
# Message case
324+
proto_msg_resp = proto_utils.ToProto.task_or_message(sample_message)
325+
assert proto_msg_resp.HasField('msg')
326+
assert not proto_msg_resp.HasField('task')
327+
roundtrip_msg = proto_utils.FromProto.task_or_message(proto_msg_resp)
328+
assert roundtrip_msg == sample_message
329+
330+
# Task case
331+
proto_task_resp = proto_utils.ToProto.task_or_message(sample_task)
332+
assert not proto_task_resp.HasField('msg')
333+
assert proto_task_resp.HasField('task')
334+
roundtrip_task = proto_utils.FromProto.task_or_message(proto_task_resp)
335+
assert roundtrip_task == sample_task
336+
161337
def test_enum_conversions(self):
162338
"""Test conversions for all enum types."""
163339
assert (
@@ -168,6 +344,11 @@ def test_enum_conversions(self):
168344
proto_utils.FromProto.role(a2a_pb2.Role.ROLE_USER)
169345
== types.Role.user
170346
)
347+
# Test unspecified role defaults to agent
348+
assert (
349+
proto_utils.FromProto.role(a2a_pb2.Role.ROLE_UNSPECIFIED)
350+
== types.Role.agent
351+
)
171352

172353
for state in types.TaskState:
173354
if state not in (types.TaskState.unknown, types.TaskState.rejected):
@@ -232,18 +413,68 @@ def test_oauth_flows_conversion(self):
232413
)
233414
assert roundtrip_implicit.implicit is not None
234415

416+
roundtrip_auth_code = proto_utils.FromProto.oauth2_flows(
417+
proto_auth_code_flow
418+
)
419+
assert roundtrip_auth_code.authorization_code is not None
420+
421+
def test_task_id_params_from_proto(self):
422+
"""Test successful parsing of task ID from request names."""
423+
request = a2a_pb2.CancelTaskRequest(name='tasks/task-123')
424+
params = proto_utils.FromProto.task_id_params(request)
425+
assert params.id == 'task-123'
426+
427+
push_request = a2a_pb2.GetTaskPushNotificationConfigRequest(
428+
name='tasks/task-456/pushNotificationConfigs/config-abc'
429+
)
430+
params = proto_utils.FromProto.task_id_params(push_request)
431+
assert params.id == 'task-456'
432+
235433
def test_task_id_params_from_proto_invalid_name(self):
236434
request = a2a_pb2.CancelTaskRequest(name='invalid-name-format')
237435
with pytest.raises(ServerError) as exc_info:
238436
proto_utils.FromProto.task_id_params(request)
239437
assert isinstance(exc_info.value.error, types.InvalidParamsError)
240438

439+
def test_task_push_config_from_proto(self):
440+
"""Test successful parsing of task push notification config."""
441+
name = 'tasks/task-789/pushNotificationConfigs/config-xyz'
442+
proto_config = a2a_pb2.TaskPushNotificationConfig(
443+
name=name,
444+
push_notification_config=a2a_pb2.PushNotificationConfig(id='cfg-1'),
445+
)
446+
config = proto_utils.FromProto.task_push_notification_config(
447+
proto_config
448+
)
449+
assert config.task_id == 'task-789'
450+
assert config.push_notification_config.id == 'cfg-1'
451+
241452
def test_task_push_config_from_proto_invalid_parent(self):
242453
request = a2a_pb2.TaskPushNotificationConfig(name='invalid-name-format')
243454
with pytest.raises(ServerError) as exc_info:
244455
proto_utils.FromProto.task_push_notification_config(request)
245456
assert isinstance(exc_info.value.error, types.InvalidParamsError)
246457

458+
def test_from_proto_task_query_params(self):
459+
"""Test successful parsing of task query parameters."""
460+
request = a2a_pb2.GetTaskRequest(
461+
name='tasks/task-abc', history_length=50
462+
)
463+
params = proto_utils.FromProto.task_query_params(request)
464+
assert params.id == 'task-abc'
465+
assert params.history_length == 50
466+
467+
def test_dict_to_struct(self):
468+
"""Test the dict_to_struct utility function."""
469+
py_dict = {'a': 1, 'b': 'hello', 'c': {'d': True}}
470+
struct = proto_utils.dict_to_struct(py_dict)
471+
472+
assert isinstance(struct, struct_pb2.Struct)
473+
assert struct['a'] == 1
474+
assert struct['b'] == 'hello'
475+
assert isinstance(struct['c'], struct_pb2.Struct)
476+
assert struct['c']['d'] is True
477+
247478
def test_none_handling(self):
248479
"""Test that None inputs are handled gracefully."""
249480
assert proto_utils.ToProto.message(None) is None

0 commit comments

Comments
 (0)