Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 80 additions & 94 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,34 @@


# Regexp patterns for matching
_TASK_NAME_MATCH = r'tasks/([\w-]+)'
_TASK_PUSH_CONFIG_NAME_MATCH = (
_TASK_NAME_MATCH = re.compile(r'tasks/([\w-]+)')
_TASK_PUSH_CONFIG_NAME_MATCH = re.compile(
r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)'
)


def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct:
"""Converts a Python dict to a Struct proto.

Unfortunately, using `json_format.ParseDict` does not work because this
wants the dictionary to be an exact match of the Struct proto with fields
and keys and values, not the traditional Python dict structure.

Args:
dictionary: The Python dict to convert.

Returns:
The Struct proto.
"""
struct = struct_pb2.Struct()
for key, val in dictionary.items():
if isinstance(val, dict):
struct[key] = dict_to_struct(val)
else:
struct[key] = val
return struct


class ToProto:
"""Converts Python types to proto types."""

Expand All @@ -33,11 +55,11 @@
return None
return a2a_pb2.Message(
message_id=message.message_id,
content=[ToProto.part(p) for p in message.parts],
content=[cls.part(p) for p in message.parts],
context_id=message.context_id or '',
task_id=message.task_id or '',
role=cls.role(message.role),
metadata=ToProto.metadata(message.metadata),
metadata=cls.metadata(message.metadata),
)

@classmethod
Expand All @@ -53,20 +75,14 @@
if isinstance(part.root, types.TextPart):
return a2a_pb2.Part(text=part.root.text)
if isinstance(part.root, types.FilePart):
return a2a_pb2.Part(file=ToProto.file(part.root.file))
return a2a_pb2.Part(file=cls.file(part.root.file))
if isinstance(part.root, types.DataPart):
return a2a_pb2.Part(data=ToProto.data(part.root.data))
return a2a_pb2.Part(data=cls.data(part.root.data))
raise ValueError(f'Unsupported part type: {part.root}')

@classmethod
def data(cls, data: dict[str, Any]) -> a2a_pb2.DataPart:
json_data = json.dumps(data)
return a2a_pb2.DataPart(
data=json_format.Parse(
json_data,
struct_pb2.Struct(),
)
)
return a2a_pb2.DataPart(data=dict_to_struct(data))

@classmethod
def file(
Expand All @@ -87,14 +103,14 @@
return a2a_pb2.Task(
id=task.id,
context_id=task.context_id,
status=ToProto.task_status(task.status),
status=cls.task_status(task.status),
artifacts=(
[ToProto.artifact(a) for a in task.artifacts]
[cls.artifact(a) for a in task.artifacts]
if task.artifacts
else None
),
history=(
[ToProto.message(h) for h in task.history] # type: ignore[misc]
[cls.message(h) for h in task.history] # type: ignore[misc]
if task.history
else None
),
Expand All @@ -103,8 +119,8 @@
@classmethod
def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus:
return a2a_pb2.TaskStatus(
state=ToProto.task_state(status.state),
update=ToProto.message(status.message),
state=cls.task_state(status.state),
update=cls.message(status.message),
)

@classmethod
Expand All @@ -129,16 +145,16 @@

@classmethod
def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact:
return a2a_pb2.Artifact(
artifact_id=artifact.artifact_id,
description=artifact.description,
metadata=ToProto.metadata(artifact.metadata),
metadata=cls.metadata(artifact.metadata),
name=artifact.name,
parts=[ToProto.part(p) for p in artifact.parts],
parts=[cls.part(p) for p in artifact.parts],
)

@classmethod
def authentication_info(

Check notice on line 157 in src/a2a/utils/proto_utils.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/utils/proto_utils.py (582-591)
cls, info: types.PushNotificationAuthenticationInfo
) -> a2a_pb2.AuthenticationInfo:
return a2a_pb2.AuthenticationInfo(
Expand All @@ -151,7 +167,7 @@
cls, config: types.PushNotificationConfig
) -> a2a_pb2.PushNotificationConfig:
auth_info = (
ToProto.authentication_info(config.authentication)
cls.authentication_info(config.authentication)
if config.authentication
else None
)
Expand All @@ -169,8 +185,8 @@
return a2a_pb2.TaskArtifactUpdateEvent(
task_id=event.task_id,
context_id=event.context_id,
artifact=ToProto.artifact(event.artifact),
metadata=ToProto.metadata(event.metadata),
artifact=cls.artifact(event.artifact),
metadata=cls.metadata(event.metadata),
append=event.append or False,
last_chunk=event.last_chunk or False,
)
Expand All @@ -182,8 +198,8 @@
return a2a_pb2.TaskStatusUpdateEvent(
task_id=event.task_id,
context_id=event.context_id,
status=ToProto.task_status(event.status),
metadata=ToProto.metadata(event.metadata),
status=cls.task_status(event.status),
metadata=cls.metadata(event.metadata),
final=event.final,
)

Expand All @@ -195,7 +211,7 @@
return a2a_pb2.SendMessageConfiguration()
return a2a_pb2.SendMessageConfiguration(
accepted_output_modes=config.accepted_output_modes,
push_notification=ToProto.push_notification_config(
push_notification=cls.push_notification_config(
config.push_notification_config
)
if config.push_notification_config
Expand All @@ -213,19 +229,7 @@
| types.TaskArtifactUpdateEvent,
) -> a2a_pb2.StreamResponse:
"""Converts a task, message, or task update event to a StreamResponse."""
if isinstance(event, types.TaskStatusUpdateEvent):
return a2a_pb2.StreamResponse(
status_update=ToProto.task_status_update_event(event)
)
if isinstance(event, types.TaskArtifactUpdateEvent):
return a2a_pb2.StreamResponse(
artifact_update=ToProto.task_artifact_update_event(event)
)
if isinstance(event, types.Message):
return a2a_pb2.StreamResponse(msg=ToProto.message(event))
if isinstance(event, types.Task):
return a2a_pb2.StreamResponse(task=ToProto.task(event))
raise ValueError(f'Unsupported event type: {type(event)}')
return cls.stream_response(event)

@classmethod
def task_or_message(
Expand Down Expand Up @@ -257,9 +261,11 @@
return a2a_pb2.StreamResponse(
status_update=cls.task_status_update_event(event),
)
return a2a_pb2.StreamResponse(
artifact_update=cls.task_artifact_update_event(event),
)
if isinstance(event, types.TaskArtifactUpdateEvent):
return a2a_pb2.StreamResponse(
artifact_update=cls.task_artifact_update_event(event),
)
raise ValueError(f'Unsupported event type: {type(event)}')

@classmethod
def task_push_notification_config(
Expand Down Expand Up @@ -480,11 +486,11 @@
def message(cls, message: a2a_pb2.Message) -> types.Message:
return types.Message(
message_id=message.message_id,
parts=[FromProto.part(p) for p in message.content],
parts=[cls.part(p) for p in message.content],
context_id=message.context_id or None,
task_id=message.task_id or None,
role=FromProto.role(message.role),
metadata=FromProto.metadata(message.metadata),
role=cls.role(message.role),
metadata=cls.metadata(message.metadata),
)

@classmethod
Expand All @@ -498,13 +504,9 @@
if part.HasField('text'):
return types.Part(root=types.TextPart(text=part.text))
if part.HasField('file'):
return types.Part(
root=types.FilePart(file=FromProto.file(part.file))
)
return types.Part(root=types.FilePart(file=cls.file(part.file)))
if part.HasField('data'):
return types.Part(
root=types.DataPart(data=FromProto.data(part.data))
)
return types.Part(root=types.DataPart(data=cls.data(part.data)))
raise ValueError(f'Unsupported part type: {part}')

@classmethod
Expand Down Expand Up @@ -543,16 +545,16 @@
return types.Task(
id=task.id,
context_id=task.context_id,
status=FromProto.task_status(task.status),
artifacts=[FromProto.artifact(a) for a in task.artifacts],
history=[FromProto.message(h) for h in task.history],
status=cls.task_status(task.status),
artifacts=[cls.artifact(a) for a in task.artifacts],
history=[cls.message(h) for h in task.history],
)

@classmethod
def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus:
return types.TaskStatus(
state=FromProto.task_state(status.state),
message=FromProto.message(status.update),
state=cls.task_state(status.state),
message=cls.message(status.update),
)

@classmethod
Expand All @@ -577,23 +579,23 @@

@classmethod
def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact:
return types.Artifact(
artifact_id=artifact.artifact_id,
description=artifact.description,
metadata=FromProto.metadata(artifact.metadata),
metadata=cls.metadata(artifact.metadata),
name=artifact.name,
parts=[FromProto.part(p) for p in artifact.parts],
parts=[cls.part(p) for p in artifact.parts],
)

@classmethod
def task_artifact_update_event(

Check notice on line 591 in src/a2a/utils/proto_utils.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/utils/proto_utils.py (148-157)
cls, event: a2a_pb2.TaskArtifactUpdateEvent
) -> types.TaskArtifactUpdateEvent:
return types.TaskArtifactUpdateEvent(
task_id=event.task_id,
context_id=event.context_id,
artifact=FromProto.artifact(event.artifact),
metadata=FromProto.metadata(event.metadata),
artifact=cls.artifact(event.artifact),
metadata=cls.metadata(event.metadata),
append=event.append,
last_chunk=event.last_chunk,
)
Expand All @@ -605,8 +607,8 @@
return types.TaskStatusUpdateEvent(
task_id=event.task_id,
context_id=event.context_id,
status=FromProto.task_status(event.status),
metadata=FromProto.metadata(event.metadata),
status=cls.task_status(event.status),
metadata=cls.metadata(event.metadata),
final=event.final,
)

Expand All @@ -618,7 +620,7 @@
id=config.id,
url=config.url,
token=config.token,
authentication=FromProto.authentication_info(config.authentication)
authentication=cls.authentication_info(config.authentication)
if config.HasField('authentication')
else None,
)
Expand All @@ -638,7 +640,7 @@
) -> types.MessageSendConfiguration:
return types.MessageSendConfiguration(
accepted_output_modes=list(config.accepted_output_modes),
push_notification_config=FromProto.push_notification_config(
push_notification_config=cls.push_notification_config(
config.push_notification
)
if config.HasField('push_notification')
Expand Down Expand Up @@ -666,18 +668,16 @@
| a2a_pb2.GetTaskPushNotificationConfigRequest
),
) -> types.TaskIdParams:
# This is currently incomplete until the core sdk supports multiple
# configs for a single task.
if isinstance(request, a2a_pb2.GetTaskPushNotificationConfigRequest):
m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, request.name)
m = _TASK_PUSH_CONFIG_NAME_MATCH.match(request.name)
if not m:
raise ServerError(
error=types.InvalidParamsError(
message=f'No task for {request.name}'
)
)
return types.TaskIdParams(id=m.group(1))
m = re.match(_TASK_NAME_MATCH, request.name)
m = _TASK_NAME_MATCH.match(request.name)
if not m:
raise ServerError(
error=types.InvalidParamsError(
Expand All @@ -691,7 +691,7 @@
cls,
request: a2a_pb2.CreateTaskPushNotificationConfigRequest,
) -> types.TaskPushNotificationConfig:
m = re.match(_TASK_NAME_MATCH, request.parent)
m = _TASK_NAME_MATCH.match(request.parent)
if not m:
raise ServerError(
error=types.InvalidParamsError(
Expand All @@ -710,7 +710,7 @@
cls,
config: a2a_pb2.TaskPushNotificationConfig,
) -> types.TaskPushNotificationConfig:
m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, config.name)
m = _TASK_PUSH_CONFIG_NAME_MATCH.match(config.name)
if not m:
raise ServerError(
error=types.InvalidParamsError(
Expand Down Expand Up @@ -767,7 +767,7 @@
cls,
request: a2a_pb2.GetTaskRequest,
) -> types.TaskQueryParams:
m = re.match(_TASK_NAME_MATCH, request.name)
m = _TASK_NAME_MATCH.match(request.name)
if not m:
raise ServerError(
error=types.InvalidParamsError(
Expand Down Expand Up @@ -862,6 +862,12 @@
flows=cls.oauth2_flows(scheme.oauth2_security_scheme.flows),
)
)
if scheme.HasField('mtls_security_scheme'):
return types.SecurityScheme(
root=types.MutualTLSSecurityScheme(
description=scheme.mtls_security_scheme.description,
)
)
return types.SecurityScheme(
root=types.OpenIdConnectSecurityScheme(
description=scheme.open_id_connect_security_scheme.description,
Expand Down Expand Up @@ -920,7 +926,9 @@
return cls.task(response.task)
if response.HasField('status_update'):
return cls.task_status_update_event(response.status_update)
return cls.task_artifact_update_event(response.artifact_update)
if response.HasField('artifact_update'):
return cls.task_artifact_update_event(response.artifact_update)
raise ValueError('Unsupported StreamResponse type')

@classmethod
def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill:
Expand All @@ -943,25 +951,3 @@
return types.Role.agent
case _:
return types.Role.agent


def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct:
"""Converts a Python dict to a Struct proto.

Unfortunately, using `json_format.ParseDict` does not work because this
wants the dictionary to be an exact match of the Struct proto with fields
and keys and values, not the traditional Python dict structure.

Args:
dictionary: The Python dict to convert.

Returns:
The Struct proto.
"""
struct = struct_pb2.Struct()
for key, val in dictionary.items():
if isinstance(val, dict):
struct[key] = dict_to_struct(val)
else:
struct[key] = val
return struct
Loading