diff --git a/src/a2a/contrib/tasks/vertex_task_converter.py b/src/a2a/contrib/tasks/vertex_task_converter.py index 5015211c..f2afe56d 100644 --- a/src/a2a/contrib/tasks/vertex_task_converter.py +++ b/src/a2a/contrib/tasks/vertex_task_converter.py @@ -11,13 +11,18 @@ import base64 import json +from dataclasses import dataclass +from typing import Any + from a2a.types import ( Artifact, DataPart, FilePart, FileWithBytes, FileWithUri, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -25,6 +30,15 @@ ) +_ORIGINAL_METADATA_KEY = 'originalMetadata' +_EXTENSIONS_KEY = 'extensions' +_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds' +_PART_METADATA_KEY = 'partMetadata' +_PART_TYPES_KEY = 'partTypes' +_METADATA_VERSION_KEY = '__vertex_compat_v' +_METADATA_VERSION_NUMBER = 1.0 + + _TO_SDK_TASK_STATE = { vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown, vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted, @@ -52,6 +66,62 @@ def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState: ) +def to_stored_metadata( + original_metadata: dict[str, Any] | None, + extensions: list[str] | None, + reference_task_ids: list[str] | None, + parts: list[Part], +) -> dict[str, Any]: + """Packs original metadata, extensions, and part types/metadata into a storage dictionary.""" + metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER} + if original_metadata: + metadata[_ORIGINAL_METADATA_KEY] = original_metadata + if extensions: + metadata[_EXTENSIONS_KEY] = extensions + if reference_task_ids: + metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids + + part_types = [] + part_metadata = [] + for part in parts: + part_types.append('data' if isinstance(part.root, DataPart) else '') + part_metadata.append(part.root.metadata) + + metadata[_PART_TYPES_KEY] = part_types + metadata[_PART_METADATA_KEY] = part_metadata + + return metadata + + +@dataclass +class _UnpackedMetadata: + original_metadata: dict[str, Any] | None = None + extensions: list[str] | None = None + reference_task_ids: list[str] | None = None + part_metadata: list[dict[str, Any] | None] | None = None + part_types: list[str] | None = None + + +def to_sdk_metadata( + stored_metadata: dict[str, Any] | None, +) -> _UnpackedMetadata: + """Unpacks metadata, extensions, and part types/metadata from a storage dictionary.""" + if not stored_metadata: + return _UnpackedMetadata() + + version = stored_metadata.get(_METADATA_VERSION_KEY) + if version is None: + return _UnpackedMetadata(original_metadata=stored_metadata) + + return _UnpackedMetadata( + original_metadata=stored_metadata.get(_ORIGINAL_METADATA_KEY), + extensions=stored_metadata.get(_EXTENSIONS_KEY), + reference_task_ids=stored_metadata.get(_REFERENCE_TASK_IDS_KEY), + part_metadata=stored_metadata.get(_PART_METADATA_KEY), + part_types=stored_metadata.get(_PART_TYPES_KEY), + ) + + def to_stored_part(part: Part) -> genai_types.Part: """Converts a SDK Part to a proto Part.""" if isinstance(part.root, TextPart): @@ -82,20 +152,32 @@ def to_stored_part(part: Part) -> genai_types.Part: raise ValueError(f'Unsupported part type: {type(part.root)}') -def to_sdk_part(stored_part: genai_types.Part) -> Part: +def to_sdk_part( + stored_part: genai_types.Part, + part_metadata: dict[str, Any] | None = None, + part_type: str = '', +) -> Part: """Converts a proto Part to a SDK Part.""" if stored_part.text: - return Part(root=TextPart(text=stored_part.text)) + return Part( + root=TextPart(text=stored_part.text, metadata=part_metadata) + ) if stored_part.inline_data: + mime_type = stored_part.inline_data.mime_type + if part_type == 'data' and mime_type == 'application/json': + data_dict = json.loads(stored_part.inline_data.data or b'{}') + return Part(root=DataPart(data=data_dict, metadata=part_metadata)) + encoded_bytes = base64.b64encode( stored_part.inline_data.data or b'' ).decode('utf-8') return Part( root=FilePart( file=FileWithBytes( - mime_type=stored_part.inline_data.mime_type, + mime_type=mime_type, bytes=encoded_bytes, - ) + ), + metadata=part_metadata, ) ) if stored_part.file_data: @@ -103,8 +185,9 @@ def to_sdk_part(stored_part: genai_types.Part) -> Part: root=FilePart( file=FileWithUri( mime_type=stored_part.file_data.mime_type, - uri=stored_part.file_data.file_uri, - ) + uri=stored_part.file_data.file_uri or '', + ), + metadata=part_metadata, ) ) @@ -115,15 +198,93 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact: """Converts a SDK Artifact to a proto TaskArtifact.""" return vertexai_types.TaskArtifact( artifact_id=artifact.artifact_id, + display_name=artifact.name, + description=artifact.description, parts=[to_stored_part(part) for part in artifact.parts], + metadata=to_stored_metadata( + original_metadata=artifact.metadata, + extensions=artifact.extensions, + reference_task_ids=None, + parts=artifact.parts, + ), ) def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact: """Converts a proto TaskArtifact to a SDK Artifact.""" + unpacked_meta = to_sdk_metadata(stored_artifact.metadata) + part_metadata_list = unpacked_meta.part_metadata or [] + part_types = unpacked_meta.part_types or [] + + parts = [] + for i, part in enumerate(stored_artifact.parts or []): + meta: dict[str, Any] | None = None + if i < len(part_metadata_list): + meta = part_metadata_list[i] + ptype = '' + if i < len(part_types): + ptype = part_types[i] + parts.append(to_sdk_part(part, part_metadata=meta, part_type=ptype)) + return Artifact( artifact_id=stored_artifact.artifact_id, - parts=[to_sdk_part(part) for part in stored_artifact.parts], + name=stored_artifact.display_name, + description=stored_artifact.description, + extensions=unpacked_meta.extensions, + metadata=unpacked_meta.original_metadata, + parts=parts, + ) + + +def to_stored_message( + message: Message | None, +) -> vertexai_types.TaskMessage | None: + """Converts a SDK Message to a proto Message.""" + if not message: + return None + role = message.role.value if message.role else '' + return vertexai_types.TaskMessage( + message_id=message.message_id, + role=role, + parts=[to_stored_part(part) for part in message.parts], + metadata=to_stored_metadata( + original_metadata=message.metadata, + extensions=message.extensions, + reference_task_ids=message.reference_task_ids, + parts=message.parts, + ), + ) + + +def to_sdk_message( + stored_msg: vertexai_types.TaskMessage | None, +) -> Message | None: + """Converts a proto Message to a SDK Message.""" + if not stored_msg: + return None + unpacked_meta = to_sdk_metadata(stored_msg.metadata) + part_metadata_list = unpacked_meta.part_metadata or [] + part_types = unpacked_meta.part_types or [] + + parts = [] + for i, part in enumerate(stored_msg.parts or []): + part_metadata: dict[str, Any] | None = None + if i < len(part_metadata_list): + part_metadata = part_metadata_list[i] + part_type = '' + if i < len(part_types): + part_type = part_types[i] + parts.append( + to_sdk_part(part, part_metadata=part_metadata, part_type=part_type) + ) + + return Message( + message_id=stored_msg.message_id, + role=Role(stored_msg.role), + extensions=unpacked_meta.extensions, + reference_task_ids=unpacked_meta.reference_task_ids, + metadata=unpacked_meta.original_metadata, + parts=parts, ) @@ -133,6 +294,11 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask: context_id=task.context_id, metadata=task.metadata, state=to_stored_task_state(task.status.state), + status_details=vertexai_types.TaskStatusDetails( + task_message=to_stored_message(task.status.message) + ) + if task.status.message + else None, output=vertexai_types.TaskOutput( artifacts=[ to_stored_artifact(artifact) @@ -144,10 +310,14 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask: def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task: """Converts a proto A2aTask to a SDK Task.""" + msg: Message | None = None + if a2a_task.status_details and a2a_task.status_details.task_message: + msg = to_sdk_message(a2a_task.status_details.task_message) + return Task( id=a2a_task.name.split('/')[-1], context_id=a2a_task.context_id, - status=TaskStatus(state=to_sdk_task_state(a2a_task.state)), + status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg), metadata=a2a_task.metadata or {}, artifacts=[ to_sdk_artifact(artifact) diff --git a/tests/contrib/tasks/test_vertex_task_converter.py b/tests/contrib/tasks/test_vertex_task_converter.py index de6ae8cd..7e032b07 100644 --- a/tests/contrib/tasks/test_vertex_task_converter.py +++ b/tests/contrib/tasks/test_vertex_task_converter.py @@ -10,10 +10,12 @@ from google.genai import types as genai_types from a2a.contrib.tasks.vertex_task_converter import ( to_sdk_artifact, + to_sdk_message, to_sdk_part, to_sdk_task, to_sdk_task_state, to_stored_artifact, + to_stored_message, to_stored_part, to_stored_task, to_stored_task_state, @@ -24,7 +26,9 @@ FilePart, FileWithBytes, FileWithUri, + Message, Part, + Role, Task, TaskState, TaskStatus, @@ -313,24 +317,14 @@ def test_sdk_part_text_conversion_round_trip() -> None: def test_sdk_part_data_conversion_round_trip() -> None: - # A DataPart is converted to `inline_data` in Vertex AI, which lacks the original - # `DataPart` vs `FilePart` distinction. When reading it back from the stored - # protocol format, it becomes a `FilePart` with base64-encoded `FileWithBytes` - # and `mime_type="application/json"`. sdk_part = Part(root=DataPart(data={'key': 'value'})) stored_part = to_stored_part(sdk_part) - round_trip_sdk_part = to_sdk_part(stored_part) - - expected_b64 = base64.b64encode(b'{"key": "value"}').decode('utf-8') - assert round_trip_sdk_part == Part( - root=FilePart( - file=FileWithBytes( - bytes=expected_b64, - mime_type='application/json', - ) - ) + round_trip_sdk_part = to_sdk_part( + stored_part, part_metadata=None, part_type='data' ) + assert round_trip_sdk_part == sdk_part + def test_sdk_part_file_bytes_conversion_round_trip() -> None: encoded_b64 = base64.b64encode(b'test data').decode('utf-8') @@ -361,16 +355,6 @@ def test_sdk_part_file_uri_conversion_round_trip() -> None: assert round_trip_sdk_part == sdk_part -def test_sdk_artifact_conversion_round_trip() -> None: - sdk_artifact = Artifact( - artifact_id='art-123', - parts=[Part(root=TextPart(text='part_1'))], - ) - stored_artifact = to_stored_artifact(sdk_artifact) - round_trip_sdk_artifact = to_sdk_artifact(stored_artifact) - assert round_trip_sdk_artifact == sdk_artifact - - def test_sdk_task_conversion_round_trip() -> None: sdk_task = Task( id='task-1', @@ -403,3 +387,88 @@ def test_sdk_task_conversion_round_trip() -> None: assert round_trip_sdk_task.metadata == sdk_task.metadata assert round_trip_sdk_task.artifacts == sdk_task.artifacts assert round_trip_sdk_task.history == [] + + +def test_stored_artifact_conversion_round_trip() -> None: + """Test converting an Artifact to TaskArtifact and back restores everything.""" + original_artifact = Artifact( + artifact_id='art123', + name='My cool artifact', + description='A very interesting description', + extensions=['ext1', 'ext2'], + metadata={'custom': 'value'}, + parts=[ + Part( + root=TextPart( + text='hello', metadata={'part_meta': 'hello_meta'} + ) + ), + Part(root=DataPart(data={'foo': 'bar'})), # no metadata + ], + ) + + stored = to_stored_artifact(original_artifact) + assert isinstance(stored, vertexai_types.TaskArtifact) + + # ensure it was populated correctly + assert stored.display_name == 'My cool artifact' + assert stored.description == 'A very interesting description' + assert stored.metadata['__vertex_compat_v'] == 1.0 + + restored_artifact = to_sdk_artifact(stored) + + assert restored_artifact.artifact_id == original_artifact.artifact_id + assert restored_artifact.name == original_artifact.name + assert restored_artifact.description == original_artifact.description + assert restored_artifact.extensions == original_artifact.extensions + assert restored_artifact.metadata == original_artifact.metadata + + assert len(restored_artifact.parts) == 2 + assert isinstance(restored_artifact.parts[0].root, TextPart) + assert restored_artifact.parts[0].root.text == 'hello' + assert restored_artifact.parts[0].root.metadata == { + 'part_meta': 'hello_meta' + } + + assert isinstance(restored_artifact.parts[1].root, DataPart) + assert restored_artifact.parts[1].root.data == {'foo': 'bar'} + assert restored_artifact.parts[1].root.metadata is None + + +def test_stored_message_conversion_round_trip() -> None: + """Test converting a Message to TaskMessage and back restores everything.""" + original_message = Message( + message_id='msg456', + role=Role.agent, + reference_task_ids=['tsk2', 'tsk3'], + extensions=['ext_msg'], + metadata={'msg_meta': 42}, + parts=[ + Part(root=TextPart(text='message text')), + ], + ) + + stored = to_stored_message(original_message) + assert stored is not None + assert isinstance(stored, vertexai_types.TaskMessage) + + assert stored.message_id == 'msg456' + assert stored.role == 'agent' + assert stored.metadata['__vertex_compat_v'] == 1.0 + + restored_message = to_sdk_message(stored) + assert restored_message is not None + + assert restored_message.message_id == original_message.message_id + assert restored_message.role == original_message.role + assert ( + restored_message.reference_task_ids + == original_message.reference_task_ids + ) + assert restored_message.extensions == original_message.extensions + assert restored_message.metadata == original_message.metadata + + assert len(restored_message.parts) == 1 + assert isinstance(restored_message.parts[0].root, TextPart) + assert restored_message.parts[0].root.text == 'message text' + assert restored_message.parts[0].root.metadata is None