From 6259beac0afa624cdb332db7345b373ebf4f47ca Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Tue, 2 Sep 2025 10:32:06 -0500 Subject: [PATCH 1/2] perf: Improve performance and code style for `proto_utils.py` --- src/a2a/utils/proto_utils.py | 170 ++++++++++++++++------------------- 1 file changed, 76 insertions(+), 94 deletions(-) diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index d8c07f7c..92906fb2 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -18,12 +18,30 @@ # Regexp patterns for matching -_TASK_NAME_MATCH = r'tasks/([\w-]+)' -_TASK_PUSH_CONFIG_NAME_MATCH = ( +_TASK_NAME_MATCH = re.compile(r'tasks/([\w-]+)') +_TASK_PUSH_CONFIG_NAME_MATCH = re.compile( r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)' ) +def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct: + """Converts a Python dict to a Struct proto. + + Args: + dictionary: The Python dict to convert. + + Returns: + The Struct proto. + """ + struct = struct_pb2.Struct() + for key, val in dictionary.items(): + if isinstance(val, dict): + struct[key] = dict_to_struct(val) + else: + struct[key] = val + return struct + + class ToProto: """Converts Python types to proto types.""" @@ -33,11 +51,11 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None: return None return a2a_pb2.Message( message_id=message.message_id, - content=[ToProto.part(p) for p in message.parts], + content=[cls.part(p) for p in message.parts], context_id=message.context_id or '', task_id=message.task_id or '', role=cls.role(message.role), - metadata=ToProto.metadata(message.metadata), + metadata=cls.metadata(message.metadata), ) @classmethod @@ -53,20 +71,14 @@ def part(cls, part: types.Part) -> a2a_pb2.Part: if isinstance(part.root, types.TextPart): return a2a_pb2.Part(text=part.root.text) if isinstance(part.root, types.FilePart): - return a2a_pb2.Part(file=ToProto.file(part.root.file)) + return a2a_pb2.Part(file=cls.file(part.root.file)) if isinstance(part.root, types.DataPart): - return a2a_pb2.Part(data=ToProto.data(part.root.data)) + return a2a_pb2.Part(data=cls.data(part.root.data)) raise ValueError(f'Unsupported part type: {part.root}') @classmethod def data(cls, data: dict[str, Any]) -> a2a_pb2.DataPart: - json_data = json.dumps(data) - return a2a_pb2.DataPart( - data=json_format.Parse( - json_data, - struct_pb2.Struct(), - ) - ) + return a2a_pb2.DataPart(data=dict_to_struct(data)) @classmethod def file( @@ -87,14 +99,14 @@ def task(cls, task: types.Task) -> a2a_pb2.Task: return a2a_pb2.Task( id=task.id, context_id=task.context_id, - status=ToProto.task_status(task.status), + status=cls.task_status(task.status), artifacts=( - [ToProto.artifact(a) for a in task.artifacts] + [cls.artifact(a) for a in task.artifacts] if task.artifacts else None ), history=( - [ToProto.message(h) for h in task.history] # type: ignore[misc] + [cls.message(h) for h in task.history] # type: ignore[misc] if task.history else None ), @@ -103,8 +115,8 @@ def task(cls, task: types.Task) -> a2a_pb2.Task: @classmethod def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus: return a2a_pb2.TaskStatus( - state=ToProto.task_state(status.state), - update=ToProto.message(status.message), + state=cls.task_state(status.state), + update=cls.message(status.message), ) @classmethod @@ -132,9 +144,9 @@ def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact: return a2a_pb2.Artifact( artifact_id=artifact.artifact_id, description=artifact.description, - metadata=ToProto.metadata(artifact.metadata), + metadata=cls.metadata(artifact.metadata), name=artifact.name, - parts=[ToProto.part(p) for p in artifact.parts], + parts=[cls.part(p) for p in artifact.parts], ) @classmethod @@ -151,7 +163,7 @@ def push_notification_config( cls, config: types.PushNotificationConfig ) -> a2a_pb2.PushNotificationConfig: auth_info = ( - ToProto.authentication_info(config.authentication) + cls.authentication_info(config.authentication) if config.authentication else None ) @@ -169,8 +181,8 @@ def task_artifact_update_event( return a2a_pb2.TaskArtifactUpdateEvent( task_id=event.task_id, context_id=event.context_id, - artifact=ToProto.artifact(event.artifact), - metadata=ToProto.metadata(event.metadata), + artifact=cls.artifact(event.artifact), + metadata=cls.metadata(event.metadata), append=event.append or False, last_chunk=event.last_chunk or False, ) @@ -182,8 +194,8 @@ def task_status_update_event( return a2a_pb2.TaskStatusUpdateEvent( task_id=event.task_id, context_id=event.context_id, - status=ToProto.task_status(event.status), - metadata=ToProto.metadata(event.metadata), + status=cls.task_status(event.status), + metadata=cls.metadata(event.metadata), final=event.final, ) @@ -195,7 +207,7 @@ def message_send_configuration( return a2a_pb2.SendMessageConfiguration() return a2a_pb2.SendMessageConfiguration( accepted_output_modes=config.accepted_output_modes, - push_notification=ToProto.push_notification_config( + push_notification=cls.push_notification_config( config.push_notification_config ) if config.push_notification_config @@ -213,19 +225,7 @@ def update_event( | types.TaskArtifactUpdateEvent, ) -> a2a_pb2.StreamResponse: """Converts a task, message, or task update event to a StreamResponse.""" - if isinstance(event, types.TaskStatusUpdateEvent): - return a2a_pb2.StreamResponse( - status_update=ToProto.task_status_update_event(event) - ) - if isinstance(event, types.TaskArtifactUpdateEvent): - return a2a_pb2.StreamResponse( - artifact_update=ToProto.task_artifact_update_event(event) - ) - if isinstance(event, types.Message): - return a2a_pb2.StreamResponse(msg=ToProto.message(event)) - if isinstance(event, types.Task): - return a2a_pb2.StreamResponse(task=ToProto.task(event)) - raise ValueError(f'Unsupported event type: {type(event)}') + return cls.stream_response(event) @classmethod def task_or_message( @@ -257,9 +257,11 @@ def stream_response( return a2a_pb2.StreamResponse( status_update=cls.task_status_update_event(event), ) - return a2a_pb2.StreamResponse( - artifact_update=cls.task_artifact_update_event(event), - ) + if isinstance(event, types.TaskArtifactUpdateEvent): + return a2a_pb2.StreamResponse( + artifact_update=cls.task_artifact_update_event(event), + ) + raise ValueError(f'Unsupported event type: {type(event)}') @classmethod def task_push_notification_config( @@ -480,11 +482,11 @@ class FromProto: def message(cls, message: a2a_pb2.Message) -> types.Message: return types.Message( message_id=message.message_id, - parts=[FromProto.part(p) for p in message.content], + parts=[cls.part(p) for p in message.content], context_id=message.context_id or None, task_id=message.task_id or None, - role=FromProto.role(message.role), - metadata=FromProto.metadata(message.metadata), + role=cls.role(message.role), + metadata=cls.metadata(message.metadata), ) @classmethod @@ -498,13 +500,9 @@ def part(cls, part: a2a_pb2.Part) -> types.Part: if part.HasField('text'): return types.Part(root=types.TextPart(text=part.text)) if part.HasField('file'): - return types.Part( - root=types.FilePart(file=FromProto.file(part.file)) - ) + return types.Part(root=types.FilePart(file=cls.file(part.file))) if part.HasField('data'): - return types.Part( - root=types.DataPart(data=FromProto.data(part.data)) - ) + return types.Part(root=types.DataPart(data=cls.data(part.data))) raise ValueError(f'Unsupported part type: {part}') @classmethod @@ -543,16 +541,16 @@ def task(cls, task: a2a_pb2.Task) -> types.Task: return types.Task( id=task.id, context_id=task.context_id, - status=FromProto.task_status(task.status), - artifacts=[FromProto.artifact(a) for a in task.artifacts], - history=[FromProto.message(h) for h in task.history], + status=cls.task_status(task.status), + artifacts=[cls.artifact(a) for a in task.artifacts], + history=[cls.message(h) for h in task.history], ) @classmethod def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus: return types.TaskStatus( - state=FromProto.task_state(status.state), - message=FromProto.message(status.update), + state=cls.task_state(status.state), + message=cls.message(status.update), ) @classmethod @@ -580,9 +578,9 @@ def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact: return types.Artifact( artifact_id=artifact.artifact_id, description=artifact.description, - metadata=FromProto.metadata(artifact.metadata), + metadata=cls.metadata(artifact.metadata), name=artifact.name, - parts=[FromProto.part(p) for p in artifact.parts], + parts=[cls.part(p) for p in artifact.parts], ) @classmethod @@ -592,8 +590,8 @@ def task_artifact_update_event( return types.TaskArtifactUpdateEvent( task_id=event.task_id, context_id=event.context_id, - artifact=FromProto.artifact(event.artifact), - metadata=FromProto.metadata(event.metadata), + artifact=cls.artifact(event.artifact), + metadata=cls.metadata(event.metadata), append=event.append, last_chunk=event.last_chunk, ) @@ -605,8 +603,8 @@ def task_status_update_event( return types.TaskStatusUpdateEvent( task_id=event.task_id, context_id=event.context_id, - status=FromProto.task_status(event.status), - metadata=FromProto.metadata(event.metadata), + status=cls.task_status(event.status), + metadata=cls.metadata(event.metadata), final=event.final, ) @@ -618,7 +616,7 @@ def push_notification_config( id=config.id, url=config.url, token=config.token, - authentication=FromProto.authentication_info(config.authentication) + authentication=cls.authentication_info(config.authentication) if config.HasField('authentication') else None, ) @@ -638,7 +636,7 @@ def message_send_configuration( ) -> types.MessageSendConfiguration: return types.MessageSendConfiguration( accepted_output_modes=list(config.accepted_output_modes), - push_notification_config=FromProto.push_notification_config( + push_notification_config=cls.push_notification_config( config.push_notification ) if config.HasField('push_notification') @@ -666,10 +664,8 @@ def task_id_params( | a2a_pb2.GetTaskPushNotificationConfigRequest ), ) -> types.TaskIdParams: - # This is currently incomplete until the core sdk supports multiple - # configs for a single task. if isinstance(request, a2a_pb2.GetTaskPushNotificationConfigRequest): - m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, request.name) + m = _TASK_PUSH_CONFIG_NAME_MATCH.match(request.name) if not m: raise ServerError( error=types.InvalidParamsError( @@ -677,7 +673,7 @@ def task_id_params( ) ) return types.TaskIdParams(id=m.group(1)) - m = re.match(_TASK_NAME_MATCH, request.name) + m = _TASK_NAME_MATCH.match(request.name) if not m: raise ServerError( error=types.InvalidParamsError( @@ -691,7 +687,7 @@ def task_push_notification_config_request( cls, request: a2a_pb2.CreateTaskPushNotificationConfigRequest, ) -> types.TaskPushNotificationConfig: - m = re.match(_TASK_NAME_MATCH, request.parent) + m = _TASK_NAME_MATCH.match(request.parent) if not m: raise ServerError( error=types.InvalidParamsError( @@ -710,7 +706,7 @@ def task_push_notification_config( cls, config: a2a_pb2.TaskPushNotificationConfig, ) -> types.TaskPushNotificationConfig: - m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, config.name) + m = _TASK_PUSH_CONFIG_NAME_MATCH.match(config.name) if not m: raise ServerError( error=types.InvalidParamsError( @@ -767,7 +763,7 @@ def task_query_params( cls, request: a2a_pb2.GetTaskRequest, ) -> types.TaskQueryParams: - m = re.match(_TASK_NAME_MATCH, request.name) + m = _TASK_NAME_MATCH.match(request.name) if not m: raise ServerError( error=types.InvalidParamsError( @@ -862,6 +858,12 @@ def security_scheme( flows=cls.oauth2_flows(scheme.oauth2_security_scheme.flows), ) ) + if scheme.HasField('mtls_security_scheme'): + return types.SecurityScheme( + root=types.MutualTLSSecurityScheme( + description=scheme.mtls_security_scheme.description, + ) + ) return types.SecurityScheme( root=types.OpenIdConnectSecurityScheme( description=scheme.open_id_connect_security_scheme.description, @@ -920,7 +922,9 @@ def stream_response( return cls.task(response.task) if response.HasField('status_update'): return cls.task_status_update_event(response.status_update) - return cls.task_artifact_update_event(response.artifact_update) + if response.HasField('artifact_update'): + return cls.task_artifact_update_event(response.artifact_update) + raise ValueError('Unsupported StreamResponse type') @classmethod def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: @@ -943,25 +947,3 @@ def role(cls, role: a2a_pb2.Role) -> types.Role: return types.Role.agent case _: return types.Role.agent - - -def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct: - """Converts a Python dict to a Struct proto. - - Unfortunately, using `json_format.ParseDict` does not work because this - wants the dictionary to be an exact match of the Struct proto with fields - and keys and values, not the traditional Python dict structure. - - Args: - dictionary: The Python dict to convert. - - Returns: - The Struct proto. - """ - struct = struct_pb2.Struct() - for key, val in dictionary.items(): - if isinstance(val, dict): - struct[key] = dict_to_struct(val) - else: - struct[key] = val - return struct From cb3c2377915995244f99c732803333c3638a5b3c Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Tue, 2 Sep 2025 10:38:33 -0500 Subject: [PATCH 2/2] Re-add comment about `json_format.ParseDict` --- src/a2a/utils/proto_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 92906fb2..3806e618 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -27,6 +27,10 @@ def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct: """Converts a Python dict to a Struct proto. + Unfortunately, using `json_format.ParseDict` does not work because this + wants the dictionary to be an exact match of the Struct proto with fields + and keys and values, not the traditional Python dict structure. + Args: dictionary: The Python dict to convert.