11from unittest import mock
22
33import pytest
4+ from google .protobuf import struct_pb2
45
56from a2a import types
67from a2a .grpc import a2a_pb2
@@ -31,7 +32,30 @@ def sample_message() -> types.Message:
3132 ),
3233 types .Part (root = types .DataPart (data = {'key' : 'value' })),
3334 ],
34- metadata = {'source' : 'test' },
35+ metadata = {'source' : 'test' , 'nested' : {'key' : 'value' }},
36+ )
37+
38+
39+ @pytest .fixture
40+ def sample_message_with_bytes () -> types .Message :
41+ """A sample message that includes a file with raw bytes."""
42+ return types .Message (
43+ message_id = 'msg-2' ,
44+ context_id = 'ctx-2' ,
45+ task_id = 'task-2' ,
46+ role = types .Role .user ,
47+ parts = [
48+ types .Part (
49+ root = types .FilePart (
50+ file = types .FileWithBytes (
51+ bytes = 'file content' ,
52+ name = 'bytes.txt' ,
53+ mime_type = 'text/plain' ,
54+ ),
55+ )
56+ )
57+ ],
58+ metadata = {},
3559 )
3660
3761
@@ -47,6 +71,9 @@ def sample_task(sample_message: types.Message) -> types.Task:
4771 artifacts = [
4872 types .Artifact (
4973 artifact_id = 'art-1' ,
74+ name = 'artifact1' ,
75+ description = 'An artifact' ,
76+ metadata = {'source' : 'test' },
5077 parts = [
5178 types .Part (root = types .TextPart (text = 'Artifact content' ))
5279 ],
@@ -63,7 +90,16 @@ def sample_agent_card() -> types.AgentCard:
6390 url = 'http://localhost' ,
6491 version = '1.0.0' ,
6592 capabilities = types .AgentCapabilities (
66- streaming = True , push_notifications = True
93+ streaming = True ,
94+ push_notifications = True ,
95+ extensions = [
96+ types .AgentExtension (
97+ uri = 'ext:uri' ,
98+ description = 'An extension' ,
99+ params = {'key' : 'value' },
100+ required = True ,
101+ )
102+ ],
67103 ),
68104 default_input_modes = ['text/plain' ],
69105 default_output_modes = ['text/plain' ],
@@ -73,6 +109,9 @@ def sample_agent_card() -> types.AgentCard:
73109 name = 'Test Skill' ,
74110 description = 'A test skill' ,
75111 tags = ['test' ],
112+ examples = ['example1' ],
113+ input_modes = ['text/plain' ],
114+ output_modes = ['text/markdown' ],
76115 )
77116 ],
78117 provider = types .AgentProvider (
@@ -106,10 +145,48 @@ def sample_agent_card() -> types.AgentCard:
106145 open_id_connect_url = 'http://oidc.url'
107146 )
108147 ),
148+ 'mtls' : types .SecurityScheme (
149+ root = types .MutualTLSSecurityScheme (description = 'mTLS auth' )
150+ ),
109151 },
110152 )
111153
112154
155+ @pytest .fixture
156+ def sample_task_status_update () -> types .TaskStatusUpdateEvent :
157+ """Sample TaskStatusUpdateEvent for testing."""
158+ return types .TaskStatusUpdateEvent (
159+ task_id = 'task-1' ,
160+ context_id = 'ctx-1' ,
161+ status = types .TaskStatus (
162+ state = types .TaskState .completed ,
163+ message = types .Message (
164+ message_id = 'msg-final' ,
165+ role = types .Role .agent ,
166+ parts = [types .Part (root = types .TextPart (text = 'Done!' ))],
167+ ),
168+ ),
169+ metadata = {'final_update' : True },
170+ final = True ,
171+ )
172+
173+
174+ @pytest .fixture
175+ def sample_task_artifact_update () -> types .TaskArtifactUpdateEvent :
176+ """Sample TaskArtifactUpdateEvent for testing."""
177+ return types .TaskArtifactUpdateEvent (
178+ task_id = 'task-1' ,
179+ context_id = 'ctx-1' ,
180+ artifact = types .Artifact (
181+ artifact_id = 'art-1' ,
182+ parts = [types .Part (root = types .TextPart (text = ' some data' ))],
183+ ),
184+ metadata = {'chunk' : 2 },
185+ append = True ,
186+ last_chunk = False ,
187+ )
188+
189+
113190# --- Test Cases ---
114191
115192
@@ -127,6 +204,11 @@ class FakePartType:
127204 with pytest .raises (ValueError , match = 'Unsupported part type' ):
128205 proto_utils .ToProto .part (mock_part )
129206
207+ def test_update_event_unsupported_type (self ):
208+ """Test that ToProto.update_event raises ValueError for an unsupported type."""
209+ with pytest .raises (ValueError , match = 'Unsupported event type' ):
210+ proto_utils .ToProto .update_event ('not a valid event' )
211+
130212
131213class TestFromProto :
132214 def test_part_unsupported_type (self ):
@@ -143,6 +225,14 @@ def test_task_query_params_invalid_name(self):
143225 proto_utils .FromProto .task_query_params (request )
144226 assert isinstance (exc_info .value .error , types .InvalidParamsError )
145227
228+ def test_task_push_notification_config_request_invalid_parent (self ):
229+ request = a2a_pb2 .CreateTaskPushNotificationConfigRequest (
230+ parent = 'invalid-parent-format'
231+ )
232+ with pytest .raises (ServerError ) as exc_info :
233+ proto_utils .FromProto .task_push_notification_config_request (request )
234+ assert isinstance (exc_info .value .error , types .InvalidParamsError )
235+
146236
147237class TestProtoUtils :
148238 def test_roundtrip_message (self , sample_message : types .Message ):
@@ -158,6 +248,92 @@ def test_roundtrip_message(self, sample_message: types.Message):
158248 roundtrip_msg = proto_utils .FromProto .message (proto_msg )
159249 assert roundtrip_msg == sample_message
160250
251+ def test_roundtrip_message_with_bytes (
252+ self , sample_message_with_bytes : types .Message
253+ ):
254+ """Test conversion of a Message with file bytes to proto and back."""
255+ proto_msg = proto_utils .ToProto .message (sample_message_with_bytes )
256+ assert isinstance (proto_msg , a2a_pb2 .Message )
257+ assert proto_msg .content [0 ].file .file_with_bytes == b'file content'
258+
259+ roundtrip_msg = proto_utils .FromProto .message (proto_msg )
260+ assert roundtrip_msg == sample_message_with_bytes
261+ # Also check metadata with empty dict
262+ assert roundtrip_msg .metadata == {}
263+
264+ def test_roundtrip_task (self , sample_task : types .Task ):
265+ """Test conversion of Task to proto and back."""
266+ proto_task = proto_utils .ToProto .task (sample_task )
267+ assert isinstance (proto_task , a2a_pb2 .Task )
268+
269+ roundtrip_task = proto_utils .FromProto .task (proto_task )
270+ assert roundtrip_task == sample_task
271+
272+ def test_roundtrip_agent_card (self , sample_agent_card : types .AgentCard ):
273+ """Test conversion of AgentCard to proto and back."""
274+ proto_card = proto_utils .ToProto .agent_card (sample_agent_card )
275+ assert isinstance (proto_card , a2a_pb2 .AgentCard )
276+ assert proto_card .security_schemes ['mtls' ].HasField (
277+ 'mtls_security_scheme'
278+ )
279+
280+ roundtrip_card = proto_utils .FromProto .agent_card (proto_card )
281+ assert roundtrip_card == sample_agent_card
282+
283+ def test_roundtrip_stream_responses (
284+ self ,
285+ sample_task_status_update : types .TaskStatusUpdateEvent ,
286+ sample_task_artifact_update : types .TaskArtifactUpdateEvent ,
287+ sample_task : types .Task ,
288+ sample_message : types .Message ,
289+ ):
290+ """Test roundtrip conversion for all streamable event types."""
291+ # TaskStatusUpdateEvent
292+ proto_status_update = proto_utils .ToProto .stream_response (
293+ sample_task_status_update
294+ )
295+ roundtrip_status_update = proto_utils .FromProto .stream_response (
296+ proto_status_update
297+ )
298+ assert roundtrip_status_update == sample_task_status_update
299+
300+ # TaskArtifactUpdateEvent
301+ proto_artifact_update = proto_utils .ToProto .stream_response (
302+ sample_task_artifact_update
303+ )
304+ roundtrip_artifact_update = proto_utils .FromProto .stream_response (
305+ proto_artifact_update
306+ )
307+ assert roundtrip_artifact_update == sample_task_artifact_update
308+
309+ # Task
310+ proto_task = proto_utils .ToProto .stream_response (sample_task )
311+ roundtrip_task = proto_utils .FromProto .stream_response (proto_task )
312+ assert roundtrip_task == sample_task
313+
314+ # Message
315+ proto_message = proto_utils .ToProto .stream_response (sample_message )
316+ roundtrip_message = proto_utils .FromProto .stream_response (proto_message )
317+ assert roundtrip_message == sample_message
318+
319+ def test_task_or_message_conversion (
320+ self , sample_task : types .Task , sample_message : types .Message
321+ ):
322+ """Test ToProto and FromProto for task_or_message methods."""
323+ # Message case
324+ proto_msg_resp = proto_utils .ToProto .task_or_message (sample_message )
325+ assert proto_msg_resp .HasField ('msg' )
326+ assert not proto_msg_resp .HasField ('task' )
327+ roundtrip_msg = proto_utils .FromProto .task_or_message (proto_msg_resp )
328+ assert roundtrip_msg == sample_message
329+
330+ # Task case
331+ proto_task_resp = proto_utils .ToProto .task_or_message (sample_task )
332+ assert not proto_task_resp .HasField ('msg' )
333+ assert proto_task_resp .HasField ('task' )
334+ roundtrip_task = proto_utils .FromProto .task_or_message (proto_task_resp )
335+ assert roundtrip_task == sample_task
336+
161337 def test_enum_conversions (self ):
162338 """Test conversions for all enum types."""
163339 assert (
@@ -168,6 +344,11 @@ def test_enum_conversions(self):
168344 proto_utils .FromProto .role (a2a_pb2 .Role .ROLE_USER )
169345 == types .Role .user
170346 )
347+ # Test unspecified role defaults to agent
348+ assert (
349+ proto_utils .FromProto .role (a2a_pb2 .Role .ROLE_UNSPECIFIED )
350+ == types .Role .agent
351+ )
171352
172353 for state in types .TaskState :
173354 if state not in (types .TaskState .unknown , types .TaskState .rejected ):
@@ -232,18 +413,68 @@ def test_oauth_flows_conversion(self):
232413 )
233414 assert roundtrip_implicit .implicit is not None
234415
416+ roundtrip_auth_code = proto_utils .FromProto .oauth2_flows (
417+ proto_auth_code_flow
418+ )
419+ assert roundtrip_auth_code .authorization_code is not None
420+
421+ def test_task_id_params_from_proto (self ):
422+ """Test successful parsing of task ID from request names."""
423+ request = a2a_pb2 .CancelTaskRequest (name = 'tasks/task-123' )
424+ params = proto_utils .FromProto .task_id_params (request )
425+ assert params .id == 'task-123'
426+
427+ push_request = a2a_pb2 .GetTaskPushNotificationConfigRequest (
428+ name = 'tasks/task-456/pushNotificationConfigs/config-abc'
429+ )
430+ params = proto_utils .FromProto .task_id_params (push_request )
431+ assert params .id == 'task-456'
432+
235433 def test_task_id_params_from_proto_invalid_name (self ):
236434 request = a2a_pb2 .CancelTaskRequest (name = 'invalid-name-format' )
237435 with pytest .raises (ServerError ) as exc_info :
238436 proto_utils .FromProto .task_id_params (request )
239437 assert isinstance (exc_info .value .error , types .InvalidParamsError )
240438
439+ def test_task_push_config_from_proto (self ):
440+ """Test successful parsing of task push notification config."""
441+ name = 'tasks/task-789/pushNotificationConfigs/config-xyz'
442+ proto_config = a2a_pb2 .TaskPushNotificationConfig (
443+ name = name ,
444+ push_notification_config = a2a_pb2 .PushNotificationConfig (id = 'cfg-1' ),
445+ )
446+ config = proto_utils .FromProto .task_push_notification_config (
447+ proto_config
448+ )
449+ assert config .task_id == 'task-789'
450+ assert config .push_notification_config .id == 'cfg-1'
451+
241452 def test_task_push_config_from_proto_invalid_parent (self ):
242453 request = a2a_pb2 .TaskPushNotificationConfig (name = 'invalid-name-format' )
243454 with pytest .raises (ServerError ) as exc_info :
244455 proto_utils .FromProto .task_push_notification_config (request )
245456 assert isinstance (exc_info .value .error , types .InvalidParamsError )
246457
458+ def test_from_proto_task_query_params (self ):
459+ """Test successful parsing of task query parameters."""
460+ request = a2a_pb2 .GetTaskRequest (
461+ name = 'tasks/task-abc' , history_length = 50
462+ )
463+ params = proto_utils .FromProto .task_query_params (request )
464+ assert params .id == 'task-abc'
465+ assert params .history_length == 50
466+
467+ def test_dict_to_struct (self ):
468+ """Test the dict_to_struct utility function."""
469+ py_dict = {'a' : 1 , 'b' : 'hello' , 'c' : {'d' : True }}
470+ struct = proto_utils .dict_to_struct (py_dict )
471+
472+ assert isinstance (struct , struct_pb2 .Struct )
473+ assert struct ['a' ] == 1
474+ assert struct ['b' ] == 'hello'
475+ assert isinstance (struct ['c' ], struct_pb2 .Struct )
476+ assert struct ['c' ]['d' ] is True
477+
247478 def test_none_handling (self ):
248479 """Test that None inputs are handled gracefully."""
249480 assert proto_utils .ToProto .message (None ) is None
0 commit comments