Skip to content
Merged
71 changes: 58 additions & 13 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,7 @@ def metadata(
) -> struct_pb2.Struct | None:
if metadata is None:
return None
return struct_pb2.Struct(
# TODO: Add support for other types.
fields={
key: struct_pb2.Value(string_value=value)
for key, value in metadata.items()
if isinstance(value, str)
}
)
return dict_to_struct(metadata)

@classmethod
def part(cls, part: types.Part) -> a2a_pb2.Part:
Expand Down Expand Up @@ -324,6 +317,23 @@ def capabilities(
return a2a_pb2.AgentCapabilities(
streaming=bool(capabilities.streaming),
push_notifications=bool(capabilities.push_notifications),
extensions=[
cls.extension(x) for x in capabilities.extensions or []
],
)

@classmethod
def extension(
cls,
extension: types.AgentExtension,
) -> a2a_pb2.AgentExtension:
return a2a_pb2.AgentExtension(
uri=extension.uri,
description=extension.description,
params=dict_to_struct(extension.params)
if extension.params
else None,
required=extension.required,
)

@classmethod
Expand Down Expand Up @@ -477,11 +487,9 @@ def message(cls, message: a2a_pb2.Message) -> types.Message:

@classmethod
def metadata(cls, metadata: struct_pb2.Struct) -> dict[str, Any]:
return {
key: value.string_value
for key, value in metadata.fields.items()
if value.string_value
}
if not metadata.fields:
return {}
return json_format.MessageToDict(metadata)

@classmethod
def part(cls, part: a2a_pb2.Part) -> types.Part:
Expand Down Expand Up @@ -777,6 +785,21 @@ def capabilities(
return types.AgentCapabilities(
streaming=capabilities.streaming,
push_notifications=capabilities.push_notifications,
extensions=[
cls.agent_extension(x) for x in capabilities.extensions
],
)

@classmethod
def agent_extension(
cls,
extension: a2a_pb2.AgentExtension,
) -> types.AgentExtension:
return types.AgentExtension(
uri=extension.uri,
description=extension.description,
params=json_format.MessageToDict(extension.params),
required=extension.required,
)

@classmethod
Expand Down Expand Up @@ -916,3 +939,25 @@ def role(cls, role: a2a_pb2.Role) -> types.Role:
return types.Role.agent
case _:
return types.Role.agent


def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct:
"""Converts a Python dict to a Struct proto.

Unforunately, using the json_format.ParseDict does not work because this
wants the dictionary to be an exact match of the Struct proto with fields
and keys and values, not the traditional python dict struture.

Args:
dictionary: The Python dict to convert.

Returns:
The Struct proto.
"""
struct = struct_pb2.Struct()
for key, val in dictionary.items():
if isinstance(val, dict):
struct[key] = dict_to_struct(val)
else:
struct[key] = val
return struct
12 changes: 3 additions & 9 deletions src/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,10 @@ def new_task(request: Message) -> Task:
raise ValueError('TextPart content cannot be empty')

context_id_str = request.context_id
if context_id_str is not None:
try:
uuid.UUID(context_id_str)
context_id = context_id_str
except (ValueError, AttributeError, TypeError) as e:
raise ValueError(
f"Invalid context_id: '{context_id_str}' is not a valid UUID."
) from e
else:
if not context_id_str:
context_id = str(uuid.uuid4())
else:
context_id = context_id_str

return Task(
status=TaskStatus(state=TaskState.submitted),
Expand Down