Skip to content

Commit 118f374

Browse files
authored
Merge branch 'main' into doc/helper-utils
2 parents bd9ef79 + e3a7207 commit 118f374

File tree

3 files changed

+60
-45
lines changed

3 files changed

+60
-45
lines changed

src/a2a/utils/proto_utils.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,7 @@ def metadata(
4646
) -> struct_pb2.Struct | None:
4747
if metadata is None:
4848
return None
49-
return struct_pb2.Struct(
50-
# TODO: Add support for other types.
51-
fields={
52-
key: struct_pb2.Value(string_value=value)
53-
for key, value in metadata.items()
54-
if isinstance(value, str)
55-
}
56-
)
49+
return dict_to_struct(metadata)
5750

5851
@classmethod
5952
def part(cls, part: types.Part) -> a2a_pb2.Part:
@@ -324,6 +317,23 @@ def capabilities(
324317
return a2a_pb2.AgentCapabilities(
325318
streaming=bool(capabilities.streaming),
326319
push_notifications=bool(capabilities.push_notifications),
320+
extensions=[
321+
cls.extension(x) for x in capabilities.extensions or []
322+
],
323+
)
324+
325+
@classmethod
326+
def extension(
327+
cls,
328+
extension: types.AgentExtension,
329+
) -> a2a_pb2.AgentExtension:
330+
return a2a_pb2.AgentExtension(
331+
uri=extension.uri,
332+
description=extension.description,
333+
params=dict_to_struct(extension.params)
334+
if extension.params
335+
else None,
336+
required=extension.required,
327337
)
328338

329339
@classmethod
@@ -477,11 +487,9 @@ def message(cls, message: a2a_pb2.Message) -> types.Message:
477487

478488
@classmethod
479489
def metadata(cls, metadata: struct_pb2.Struct) -> dict[str, Any]:
480-
return {
481-
key: value.string_value
482-
for key, value in metadata.fields.items()
483-
if value.string_value
484-
}
490+
if not metadata.fields:
491+
return {}
492+
return json_format.MessageToDict(metadata)
485493

486494
@classmethod
487495
def part(cls, part: a2a_pb2.Part) -> types.Part:
@@ -777,6 +785,21 @@ def capabilities(
777785
return types.AgentCapabilities(
778786
streaming=capabilities.streaming,
779787
push_notifications=capabilities.push_notifications,
788+
extensions=[
789+
cls.agent_extension(x) for x in capabilities.extensions
790+
],
791+
)
792+
793+
@classmethod
794+
def agent_extension(
795+
cls,
796+
extension: a2a_pb2.AgentExtension,
797+
) -> types.AgentExtension:
798+
return types.AgentExtension(
799+
uri=extension.uri,
800+
description=extension.description,
801+
params=json_format.MessageToDict(extension.params),
802+
required=extension.required,
780803
)
781804

782805
@classmethod
@@ -916,3 +939,25 @@ def role(cls, role: a2a_pb2.Role) -> types.Role:
916939
return types.Role.agent
917940
case _:
918941
return types.Role.agent
942+
943+
944+
def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct:
945+
"""Converts a Python dict to a Struct proto.
946+
947+
Unfortunately, using `json_format.ParseDict` does not work because this
948+
wants the dictionary to be an exact match of the Struct proto with fields
949+
and keys and values, not the traditional Python dict structure.
950+
951+
Args:
952+
dictionary: The Python dict to convert.
953+
954+
Returns:
955+
The Struct proto.
956+
"""
957+
struct = struct_pb2.Struct()
958+
for key, val in dictionary.items():
959+
if isinstance(val, dict):
960+
struct[key] = dict_to_struct(val)
961+
else:
962+
struct[key] = val
963+
return struct

src/a2a/utils/task.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,10 @@ def new_task(request: Message) -> Task:
2828
if isinstance(part.root, TextPart) and not part.root.text:
2929
raise ValueError('TextPart content cannot be empty')
3030

31-
context_id_str = request.context_id
32-
if context_id_str is not None:
33-
try:
34-
uuid.UUID(context_id_str)
35-
context_id = context_id_str
36-
except (ValueError, AttributeError, TypeError) as e:
37-
raise ValueError(
38-
f"Invalid context_id: '{context_id_str}' is not a valid UUID."
39-
) from e
40-
else:
41-
context_id = str(uuid.uuid4())
42-
4331
return Task(
4432
status=TaskStatus(state=TaskState.submitted),
45-
id=(request.task_id if request.task_id else str(uuid.uuid4())),
46-
context_id=context_id,
33+
id=request.task_id or str(uuid.uuid4()),
34+
context_id=request.context_id or str(uuid.uuid4()),
4735
history=[request],
4836
)
4937

tests/utils/test_task.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -188,24 +188,6 @@ def test_completed_task_invalid_artifact_type(self):
188188
history=[],
189189
)
190190

191-
def test_new_task_with_invalid_context_id(self):
192-
"""Test that new_task raises a ValueError for various invalid context_id formats."""
193-
invalid_ids = ['not-a-uuid', '']
194-
for invalid_id in invalid_ids:
195-
with self.subTest(invalid_id=invalid_id):
196-
with pytest.raises(
197-
ValueError,
198-
match=f"Invalid context_id: '{invalid_id}' is not a valid UUID.",
199-
):
200-
new_task(
201-
Message(
202-
role=Role.user,
203-
parts=[Part(root=TextPart(text='test message'))],
204-
message_id=str(uuid.uuid4()),
205-
context_id=invalid_id,
206-
)
207-
)
208-
209191

210192
if __name__ == '__main__':
211193
unittest.main()

0 commit comments

Comments
 (0)