Skip to content

Commit eafdd42

Browse files
Merge branch 'main' into fix-proto-utils-metadata-serialization
2 parents 995d554 + e3a7207 commit eafdd42

File tree

5 files changed

+88
-66
lines changed

5 files changed

+88
-66
lines changed

src/a2a/grpc/a2a_pb2.py

Lines changed: 29 additions & 29 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/a2a/grpc/a2a_pb2.pyi

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -488,14 +488,12 @@ class SendMessageRequest(_message.Message):
488488
def __init__(self, request: _Optional[_Union[Message, _Mapping]] = ..., configuration: _Optional[_Union[SendMessageConfiguration, _Mapping]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
489489

490490
class GetTaskRequest(_message.Message):
491-
__slots__ = ("name", "history_length", "metadata")
491+
__slots__ = ("name", "history_length")
492492
NAME_FIELD_NUMBER: _ClassVar[int]
493493
HISTORY_LENGTH_FIELD_NUMBER: _ClassVar[int]
494-
METADATA_FIELD_NUMBER: _ClassVar[int]
495494
name: str
496495
history_length: int
497-
metadata: _struct_pb2.Struct
498-
def __init__(self, name: _Optional[str] = ..., history_length: _Optional[int] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
496+
def __init__(self, name: _Optional[str] = ..., history_length: _Optional[int] = ...) -> None: ...
499497

500498
class CancelTaskRequest(_message.Message):
501499
__slots__ = ("name",)

src/a2a/utils/proto_utils.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,23 @@ def capabilities(
372372
return a2a_pb2.AgentCapabilities(
373373
streaming=bool(capabilities.streaming),
374374
push_notifications=bool(capabilities.push_notifications),
375+
extensions=[
376+
cls.extension(x) for x in capabilities.extensions or []
377+
],
378+
)
379+
380+
@classmethod
381+
def extension(
382+
cls,
383+
extension: types.AgentExtension,
384+
) -> a2a_pb2.AgentExtension:
385+
return a2a_pb2.AgentExtension(
386+
uri=extension.uri,
387+
description=extension.description,
388+
params=dict_to_struct(extension.params)
389+
if extension.params
390+
else None,
391+
required=extension.required,
375392
)
376393

377394
@classmethod
@@ -841,7 +858,7 @@ def task_query_params(
841858
if request.history_length
842859
else None,
843860
id=m.group(1),
844-
metadata=request.metadata,
861+
metadata=None,
845862
)
846863

847864
@classmethod
@@ -851,6 +868,21 @@ def capabilities(
851868
return types.AgentCapabilities(
852869
streaming=capabilities.streaming,
853870
push_notifications=capabilities.push_notifications,
871+
extensions=[
872+
cls.agent_extension(x) for x in capabilities.extensions
873+
],
874+
)
875+
876+
@classmethod
877+
def agent_extension(
878+
cls,
879+
extension: a2a_pb2.AgentExtension,
880+
) -> types.AgentExtension:
881+
return types.AgentExtension(
882+
uri=extension.uri,
883+
description=extension.description,
884+
params=json_format.MessageToDict(extension.params),
885+
required=extension.required,
854886
)
855887

856888
@classmethod
@@ -990,3 +1022,25 @@ def role(cls, role: a2a_pb2.Role) -> types.Role:
9901022
return types.Role.agent
9911023
case _:
9921024
return types.Role.agent
1025+
1026+
1027+
def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct:
1028+
"""Converts a Python dict to a Struct proto.
1029+
1030+
Unfortunately, using `json_format.ParseDict` does not work because this
1031+
wants the dictionary to be an exact match of the Struct proto with fields
1032+
and keys and values, not the traditional Python dict structure.
1033+
1034+
Args:
1035+
dictionary: The Python dict to convert.
1036+
1037+
Returns:
1038+
The Struct proto.
1039+
"""
1040+
struct = struct_pb2.Struct()
1041+
for key, val in dictionary.items():
1042+
if isinstance(val, dict):
1043+
struct[key] = dict_to_struct(val)
1044+
else:
1045+
struct[key] = val
1046+
return struct

src/a2a/utils/task.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,10 @@ def new_task(request: Message) -> Task:
2828
if isinstance(part.root, TextPart) and not part.root.text:
2929
raise ValueError('TextPart content cannot be empty')
3030

31-
context_id_str = request.context_id
32-
if context_id_str is not None:
33-
try:
34-
uuid.UUID(context_id_str)
35-
context_id = context_id_str
36-
except (ValueError, AttributeError, TypeError) as e:
37-
raise ValueError(
38-
f"Invalid context_id: '{context_id_str}' is not a valid UUID."
39-
) from e
40-
else:
41-
context_id = str(uuid.uuid4())
42-
4331
return Task(
4432
status=TaskStatus(state=TaskState.submitted),
45-
id=(request.task_id if request.task_id else str(uuid.uuid4())),
46-
context_id=context_id,
33+
id=request.task_id or str(uuid.uuid4()),
34+
context_id=request.context_id or str(uuid.uuid4()),
4735
history=[request],
4836
)
4937

tests/utils/test_task.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -188,24 +188,6 @@ def test_completed_task_invalid_artifact_type(self):
188188
history=[],
189189
)
190190

191-
def test_new_task_with_invalid_context_id(self):
192-
"""Test that new_task raises a ValueError for various invalid context_id formats."""
193-
invalid_ids = ['not-a-uuid', '']
194-
for invalid_id in invalid_ids:
195-
with self.subTest(invalid_id=invalid_id):
196-
with pytest.raises(
197-
ValueError,
198-
match=f"Invalid context_id: '{invalid_id}' is not a valid UUID.",
199-
):
200-
new_task(
201-
Message(
202-
role=Role.user,
203-
parts=[Part(root=TextPart(text='test message'))],
204-
message_id=str(uuid.uuid4()),
205-
context_id=invalid_id,
206-
)
207-
)
208-
209191

210192
if __name__ == '__main__':
211193
unittest.main()

0 commit comments

Comments
 (0)