Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 170 additions & 7 deletions src/a2a/contrib/tasks/vertex_task_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,33 @@
import base64
import json

from typing import Any

from a2a.types import (
Artifact,
DataPart,
FilePart,
FileWithBytes,
FileWithUri,
Message,
Part,
Role,
Task,
TaskState,
TaskStatus,
TextPart,
)


_ORIGINAL_METADATA_KEY = 'originalMetadata'
_EXTENSIONS_KEY = 'extensions'
_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds'
_PART_METADATA_KEY = 'partMetadata'
_PART_TYPES_KEY = 'partTypes'
_METADATA_VERSION_KEY = '__vertex_compat_v'
_METADATA_VERSION_NUMBER = 1.0


_TO_SDK_TASK_STATE = {
vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown,
vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted,
Expand Down Expand Up @@ -52,6 +65,51 @@
)


def to_stored_metadata(
original_metadata: dict[str, Any] | None,
extensions: list[str] | None,
reference_task_ids: list[str] | None,
parts: list[Part],
) -> dict[str, Any]:
"""Packs original metadata, extensions, and part types/metadata into a storage dictionary."""
metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER}
if original_metadata:
metadata[_ORIGINAL_METADATA_KEY] = original_metadata
if extensions:
metadata[_EXTENSIONS_KEY] = extensions
if reference_task_ids:
metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids

part_types = []
part_metadata = []
for part in parts:
part_types.append('data' if isinstance(part.root, DataPart) else '')
part_metadata.append(part.root.metadata)

metadata[_PART_TYPES_KEY] = part_types
metadata[_PART_METADATA_KEY] = part_metadata

return metadata


def to_sdk_metadata(stored_metadata: dict[str, Any] | None) -> dict[str, Any]:
"""Unpacks metadata, extensions, and part types/metadata from a storage dictionary."""
if not stored_metadata:
return {}

version = stored_metadata.get(_METADATA_VERSION_KEY)
if version is None:
return {'original_metadata': stored_metadata}

return {
'original_metadata': stored_metadata.get(_ORIGINAL_METADATA_KEY),
'extensions': stored_metadata.get(_EXTENSIONS_KEY),
'reference_tasks': stored_metadata.get(_REFERENCE_TASK_IDS_KEY),
'part_metadata': stored_metadata.get(_PART_METADATA_KEY),
'part_types': stored_metadata.get(_PART_TYPES_KEY),
}


def to_stored_part(part: Part) -> genai_types.Part:
"""Converts a SDK Part to a proto Part."""
if isinstance(part.root, TextPart):
Expand Down Expand Up @@ -82,20 +140,32 @@
raise ValueError(f'Unsupported part type: {type(part.root)}')


def to_sdk_part(stored_part: genai_types.Part) -> Part:
def to_sdk_part(
stored_part: genai_types.Part,
part_metadata: dict[str, Any] | None = None,
part_type: str = '',
) -> Part:
"""Converts a proto Part to a SDK Part."""
if stored_part.text:
return Part(root=TextPart(text=stored_part.text))
return Part(
root=TextPart(text=stored_part.text, metadata=part_metadata)
)
if stored_part.inline_data:
mime_type = stored_part.inline_data.mime_type
if part_type == 'data' and mime_type == 'application/json':
data_dict = json.loads(stored_part.inline_data.data or b'{}')
return Part(root=DataPart(data=data_dict, metadata=part_metadata))

encoded_bytes = base64.b64encode(
stored_part.inline_data.data or b''
).decode('utf-8')
return Part(
root=FilePart(
file=FileWithBytes(
mime_type=stored_part.inline_data.mime_type,
mime_type=mime_type,
bytes=encoded_bytes,
)
),
metadata=part_metadata,
)
)
if stored_part.file_data:
Expand All @@ -104,7 +174,8 @@
file=FileWithUri(
mime_type=stored_part.file_data.mime_type,
uri=stored_part.file_data.file_uri,
)
),
metadata=part_metadata,
)
)

Expand All @@ -115,15 +186,98 @@
"""Converts a SDK Artifact to a proto TaskArtifact."""
return vertexai_types.TaskArtifact(
artifact_id=artifact.artifact_id,
display_name=artifact.name,
description=artifact.description,
parts=[to_stored_part(part) for part in artifact.parts],
metadata=to_stored_metadata(
original_metadata=artifact.metadata,
extensions=artifact.extensions,
reference_task_ids=None,
parts=artifact.parts,
),
)


def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
"""Converts a proto TaskArtifact to a SDK Artifact."""
unpacked_meta = to_sdk_metadata(stored_artifact.metadata)
part_metadatas = unpacked_meta.get('part_metadata') or []

Check failure on line 204 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`metadatas` is not a recognized word. (unrecognized-spelling)
part_types = unpacked_meta.get('part_types') or []

parts = []
for i, part in enumerate(stored_artifact.parts or []):
meta: dict[str, Any] | None = None
if i < len(part_metadatas):

Check failure on line 210 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`metadatas` is not a recognized word. (unrecognized-spelling)

Check warning on line 210 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`metadatas` is not a recognized word. (unrecognized-spelling)
meta = part_metadatas[i]

Check failure on line 211 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`metadatas` is not a recognized word. (unrecognized-spelling)

Check warning on line 211 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`metadatas` is not a recognized word. (unrecognized-spelling)
ptype = ''
if i < len(part_types):
ptype = part_types[i]
parts.append(to_sdk_part(part, part_metadata=meta, part_type=ptype))

return Artifact(

Check notice on line 217 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/contrib/tasks/vertex_task_converter.py (258-267)
artifact_id=stored_artifact.artifact_id,
parts=[to_sdk_part(part) for part in stored_artifact.parts],
name=stored_artifact.display_name,
description=stored_artifact.description,
extensions=unpacked_meta.get('extensions'),
metadata=unpacked_meta.get('original_metadata'),
parts=parts,
)


def to_stored_message(
message: Message | None,
) -> vertexai_types.TaskMessage | None:
"""Converts a SDK Message to a proto Message."""
if not message:
return None
role = message.role.value if message.role else ''
return vertexai_types.TaskMessage(
message_id=message.message_id,
role=role,
parts=[to_stored_part(part) for part in message.parts],
metadata=to_stored_metadata(
original_metadata=message.metadata,
extensions=message.extensions,
reference_task_ids=message.reference_task_ids,
parts=message.parts,
),
)


def to_sdk_message(
stored_msg: vertexai_types.TaskMessage | None,
) -> Message | None:
"""Converts a proto Message to a SDK Message."""
if not stored_msg:
return None
unpacked_meta = to_sdk_metadata(stored_msg.metadata)
part_metadatas = unpacked_meta.get('part_metadata') or []

Check failure on line 254 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`metadatas` is not a recognized word. (unrecognized-spelling)

Check warning on line 254 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`metadatas` is not a recognized word. (unrecognized-spelling)
part_types = unpacked_meta.get('part_types') or []

parts = []
for i, part in enumerate(stored_msg.parts or []):
meta: dict[str, Any] | None = None
if i < len(part_metadatas):

Check failure on line 260 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`metadatas` is not a recognized word. (unrecognized-spelling)

Check warning on line 260 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`metadatas` is not a recognized word. (unrecognized-spelling)
meta = part_metadatas[i]
ptype = ''
if i < len(part_types):
ptype = part_types[i]
parts.append(to_sdk_part(part, part_metadata=meta, part_type=ptype))

role = None

Check notice on line 267 in src/a2a/contrib/tasks/vertex_task_converter.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/contrib/tasks/vertex_task_converter.py (208-217)
if stored_msg.role:
try:
role = Role(stored_msg.role)
except ValueError:
role = None

return Message(
message_id=stored_msg.message_id,
role=role, # type: ignore
extensions=unpacked_meta.get('extensions'),
reference_task_ids=unpacked_meta.get('reference_tasks'),
metadata=unpacked_meta.get('original_metadata'),
parts=parts,
)


Expand All @@ -133,6 +287,11 @@
context_id=task.context_id,
metadata=task.metadata,
state=to_stored_task_state(task.status.state),
status_details=vertexai_types.TaskStatusDetails(
task_message=to_stored_message(task.status.message)
)
if task.status.message
else None,
output=vertexai_types.TaskOutput(
artifacts=[
to_stored_artifact(artifact)
Expand All @@ -144,10 +303,14 @@

def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task:
"""Converts a proto A2aTask to a SDK Task."""
msg: Message | None = None
if a2a_task.status_details and a2a_task.status_details.task_message:
msg = to_sdk_message(a2a_task.status_details.task_message)

return Task(
id=a2a_task.name.split('/')[-1],
context_id=a2a_task.context_id,
status=TaskStatus(state=to_sdk_task_state(a2a_task.state)),
status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg),
metadata=a2a_task.metadata or {},
artifacts=[
to_sdk_artifact(artifact)
Expand Down
Loading
Loading