diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index 15f40265..e5b5663d 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -1,6 +1,7 @@ """Utility functions for the A2A Python SDK.""" from a2a.utils.artifact import ( + get_artifact_text, new_artifact, new_data_artifact, new_text_artifact, @@ -18,13 +19,15 @@ create_task_obj, ) from a2a.utils.message import ( - get_data_parts, - get_file_parts, get_message_text, - get_text_parts, new_agent_parts_message, new_agent_text_message, ) +from a2a.utils.parts import ( + get_data_parts, + get_file_parts, + get_text_parts, +) from a2a.utils.task import ( completed_task, new_task, @@ -41,6 +44,7 @@ 'build_text_artifact', 'completed_task', 'create_task_obj', + 'get_artifact_text', 'get_data_parts', 'get_file_parts', 'get_message_text', diff --git a/src/a2a/utils/artifact.py b/src/a2a/utils/artifact.py index 1cf0a89a..03e8adaa 100644 --- a/src/a2a/utils/artifact.py +++ b/src/a2a/utils/artifact.py @@ -5,6 +5,7 @@ from typing import Any from a2a.types import Artifact, DataPart, Part, TextPart +from a2a.utils.parts import get_text_parts def new_artifact( @@ -70,3 +71,16 @@ def new_data_artifact( name, description, ) + + +def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str: + """Extracts and joins all text content from an Artifact's parts. + + Args: + artifact: The `Artifact` object. + delimiter: The string to use when joining text from multiple TextParts. + + Returns: + A single string containing all text content, or an empty string if no text parts are found. + """ + return delimiter.join(get_text_parts(artifact.parts)) diff --git a/src/a2a/utils/message.py b/src/a2a/utils/message.py index 4d78cd46..bfd675fd 100644 --- a/src/a2a/utils/message.py +++ b/src/a2a/utils/message.py @@ -2,18 +2,13 @@ import uuid -from typing import Any - from a2a.types import ( - DataPart, - FilePart, - FileWithBytes, - FileWithUri, Message, Part, Role, TextPart, ) +from a2a.utils.parts import get_text_parts def new_agent_text_message( @@ -64,42 +59,6 @@ def new_agent_parts_message( ) -def get_text_parts(parts: list[Part]) -> list[str]: - """Extracts text content from all TextPart objects in a list of Parts. - - Args: - parts: A list of `Part` objects. - - Returns: - A list of strings containing the text content from any `TextPart` objects found. - """ - return [part.root.text for part in parts if isinstance(part.root, TextPart)] - - -def get_data_parts(parts: list[Part]) -> list[dict[str, Any]]: - """Extracts dictionary data from all DataPart objects in a list of Parts. - - Args: - parts: A list of `Part` objects. - - Returns: - A list of dictionaries containing the data from any `DataPart` objects found. - """ - return [part.root.data for part in parts if isinstance(part.root, DataPart)] - - -def get_file_parts(parts: list[Part]) -> list[FileWithBytes | FileWithUri]: - """Extracts file data from all FilePart objects in a list of Parts. - - Args: - parts: A list of `Part` objects. - - Returns: - A list of `FileWithBytes` or `FileWithUri` objects containing the file data from any `FilePart` objects found. - """ - return [part.root.file for part in parts if isinstance(part.root, FilePart)] - - def get_message_text(message: Message, delimiter: str = '\n') -> str: """Extracts and joins all text content from a Message's parts. diff --git a/src/a2a/utils/parts.py b/src/a2a/utils/parts.py new file mode 100644 index 00000000..f32076c8 --- /dev/null +++ b/src/a2a/utils/parts.py @@ -0,0 +1,48 @@ +"""Utility functions for creating and handling A2A Parts objects.""" + +from typing import Any + +from a2a.types import ( + DataPart, + FilePart, + FileWithBytes, + FileWithUri, + Part, + TextPart, +) + + +def get_text_parts(parts: list[Part]) -> list[str]: + """Extracts text content from all TextPart objects in a list of Parts. + + Args: + parts: A list of `Part` objects. + + Returns: + A list of strings containing the text content from any `TextPart` objects found. + """ + return [part.root.text for part in parts if isinstance(part.root, TextPart)] + + +def get_data_parts(parts: list[Part]) -> list[dict[str, Any]]: + """Extracts dictionary data from all DataPart objects in a list of Parts. + + Args: + parts: A list of `Part` objects. + + Returns: + A list of dictionaries containing the data from any `DataPart` objects found. + """ + return [part.root.data for part in parts if isinstance(part.root, DataPart)] + + +def get_file_parts(parts: list[Part]) -> list[FileWithBytes | FileWithUri]: + """Extracts file data from all FilePart objects in a list of Parts. + + Args: + parts: A list of `Part` objects. + + Returns: + A list of `FileWithBytes` or `FileWithUri` objects containing the file data from any `FilePart` objects found. + """ + return [part.root.file for part in parts if isinstance(part.root, FilePart)] diff --git a/tests/utils/test_artifact.py b/tests/utils/test_artifact.py index 132d0567..c3590c17 100644 --- a/tests/utils/test_artifact.py +++ b/tests/utils/test_artifact.py @@ -3,8 +3,14 @@ from unittest.mock import patch -from a2a.types import DataPart, Part, TextPart +from a2a.types import ( + Artifact, + DataPart, + Part, + TextPart, +) from a2a.utils.artifact import ( + get_artifact_text, new_artifact, new_data_artifact, new_text_artifact, @@ -83,5 +89,71 @@ def test_new_data_artifact_assigns_name_description(self): self.assertEqual(artifact.description, description) +class TestGetArtifactText(unittest.TestCase): + def test_get_artifact_text_single_part(self): + # Setup + artifact = Artifact( + name='test-artifact', + parts=[Part(root=TextPart(text='Hello world'))], + artifact_id='test-artifact-id', + ) + + # Exercise + result = get_artifact_text(artifact) + + # Verify + assert result == 'Hello world' + + def test_get_artifact_text_multiple_parts(self): + # Setup + artifact = Artifact( + name='test-artifact', + parts=[ + Part(root=TextPart(text='First line')), + Part(root=TextPart(text='Second line')), + Part(root=TextPart(text='Third line')), + ], + artifact_id='test-artifact-id', + ) + + # Exercise + result = get_artifact_text(artifact) + + # Verify - default delimiter is newline + assert result == 'First line\nSecond line\nThird line' + + def test_get_artifact_text_custom_delimiter(self): + # Setup + artifact = Artifact( + name='test-artifact', + parts=[ + Part(root=TextPart(text='First part')), + Part(root=TextPart(text='Second part')), + Part(root=TextPart(text='Third part')), + ], + artifact_id='test-artifact-id', + ) + + # Exercise + result = get_artifact_text(artifact, delimiter=' | ') + + # Verify + assert result == 'First part | Second part | Third part' + + def test_get_artifact_text_empty_parts(self): + # Setup + artifact = Artifact( + name='test-artifact', + parts=[], + artifact_id='test-artifact-id', + ) + + # Exercise + result = get_artifact_text(artifact) + + # Verify + assert result == '' + + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_message.py b/tests/utils/test_message.py index 3270eab7..11523cbd 100644 --- a/tests/utils/test_message.py +++ b/tests/utils/test_message.py @@ -4,19 +4,13 @@ from a2a.types import ( DataPart, - FilePart, - FileWithBytes, - FileWithUri, Message, Part, Role, TextPart, ) from a2a.utils.message import ( - get_data_parts, - get_file_parts, get_message_text, - get_text_parts, new_agent_parts_message, new_agent_text_message, ) @@ -147,177 +141,6 @@ def test_new_agent_parts_message(self): assert message.message_id == 'abcdefab-cdef-abcd-efab-cdefabcdefab' -class TestGetTextParts: - def test_get_text_parts_single_text_part(self): - # Setup - parts = [Part(root=TextPart(text='Hello world'))] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == ['Hello world'] - - def test_get_text_parts_multiple_text_parts(self): - # Setup - parts = [ - Part(root=TextPart(text='First part')), - Part(root=TextPart(text='Second part')), - Part(root=TextPart(text='Third part')), - ] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == ['First part', 'Second part', 'Third part'] - - def test_get_text_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == [] - - -class TestGetDataParts: - def test_get_data_parts_single_data_part(self): - # Setup - parts = [Part(root=DataPart(data={'key': 'value'}))] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key': 'value'}] - - def test_get_data_parts_multiple_data_parts(self): - # Setup - parts = [ - Part(root=DataPart(data={'key1': 'value1'})), - Part(root=DataPart(data={'key2': 'value2'})), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key1': 'value1'}, {'key2': 'value2'}] - - def test_get_data_parts_mixed_parts(self): - # Setup - parts = [ - Part(root=TextPart(text='some text')), - Part(root=DataPart(data={'key1': 'value1'})), - Part(root=DataPart(data={'key2': 'value2'})), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key1': 'value1'}, {'key2': 'value2'}] - - def test_get_data_parts_no_data_parts(self): - # Setup - parts = [ - Part(root=TextPart(text='some text')), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [] - - def test_get_data_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [] - - -class TestGetFileParts: - def test_get_file_parts_single_file_part(self): - # Setup - file_with_uri = FileWithUri( - uri='file://path/to/file', mimeType='text/plain' - ) - parts = [Part(root=FilePart(file=file_with_uri))] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [file_with_uri] - - def test_get_file_parts_multiple_file_parts(self): - # Setup - file_with_uri1 = FileWithUri( - uri='file://path/to/file1', mime_type='text/plain' - ) - file_with_bytes = FileWithBytes( - bytes='ZmlsZSBjb250ZW50', - mime_type='application/octet-stream', # 'file content' - ) - parts = [ - Part(root=FilePart(file=file_with_uri1)), - Part(root=FilePart(file=file_with_bytes)), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [file_with_uri1, file_with_bytes] - - def test_get_file_parts_mixed_parts(self): - # Setup - file_with_uri = FileWithUri( - uri='file://path/to/file', mime_type='text/plain' - ) - parts = [ - Part(root=TextPart(text='some text')), - Part(root=FilePart(file=file_with_uri)), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [file_with_uri] - - def test_get_file_parts_no_file_parts(self): - # Setup - parts = [ - Part(root=TextPart(text='some text')), - Part(root=DataPart(data={'key': 'value'})), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [] - - def test_get_file_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [] - - class TestGetMessageText: def test_get_message_text_single_part(self): # Setup diff --git a/tests/utils/test_parts.py b/tests/utils/test_parts.py new file mode 100644 index 00000000..dcb027c2 --- /dev/null +++ b/tests/utils/test_parts.py @@ -0,0 +1,184 @@ +from a2a.types import ( + DataPart, + FilePart, + FileWithBytes, + FileWithUri, + Part, + TextPart, +) +from a2a.utils.parts import ( + get_data_parts, + get_file_parts, + get_text_parts, +) + + +class TestGetTextParts: + def test_get_text_parts_single_text_part(self): + # Setup + parts = [Part(root=TextPart(text='Hello world'))] + + # Exercise + result = get_text_parts(parts) + + # Verify + assert result == ['Hello world'] + + def test_get_text_parts_multiple_text_parts(self): + # Setup + parts = [ + Part(root=TextPart(text='First part')), + Part(root=TextPart(text='Second part')), + Part(root=TextPart(text='Third part')), + ] + + # Exercise + result = get_text_parts(parts) + + # Verify + assert result == ['First part', 'Second part', 'Third part'] + + def test_get_text_parts_empty_list(self): + # Setup + parts = [] + + # Exercise + result = get_text_parts(parts) + + # Verify + assert result == [] + + +class TestGetDataParts: + def test_get_data_parts_single_data_part(self): + # Setup + parts = [Part(root=DataPart(data={'key': 'value'}))] + + # Exercise + result = get_data_parts(parts) + + # Verify + assert result == [{'key': 'value'}] + + def test_get_data_parts_multiple_data_parts(self): + # Setup + parts = [ + Part(root=DataPart(data={'key1': 'value1'})), + Part(root=DataPart(data={'key2': 'value2'})), + ] + + # Exercise + result = get_data_parts(parts) + + # Verify + assert result == [{'key1': 'value1'}, {'key2': 'value2'}] + + def test_get_data_parts_mixed_parts(self): + # Setup + parts = [ + Part(root=TextPart(text='some text')), + Part(root=DataPart(data={'key1': 'value1'})), + Part(root=DataPart(data={'key2': 'value2'})), + ] + + # Exercise + result = get_data_parts(parts) + + # Verify + assert result == [{'key1': 'value1'}, {'key2': 'value2'}] + + def test_get_data_parts_no_data_parts(self): + # Setup + parts = [ + Part(root=TextPart(text='some text')), + ] + + # Exercise + result = get_data_parts(parts) + + # Verify + assert result == [] + + def test_get_data_parts_empty_list(self): + # Setup + parts = [] + + # Exercise + result = get_data_parts(parts) + + # Verify + assert result == [] + + +class TestGetFileParts: + def test_get_file_parts_single_file_part(self): + # Setup + file_with_uri = FileWithUri( + uri='file://path/to/file', mimeType='text/plain' + ) + parts = [Part(root=FilePart(file=file_with_uri))] + + # Exercise + result = get_file_parts(parts) + + # Verify + assert result == [file_with_uri] + + def test_get_file_parts_multiple_file_parts(self): + # Setup + file_with_uri1 = FileWithUri( + uri='file://path/to/file1', mime_type='text/plain' + ) + file_with_bytes = FileWithBytes( + bytes='ZmlsZSBjb250ZW50', + mime_type='application/octet-stream', # 'file content' + ) + parts = [ + Part(root=FilePart(file=file_with_uri1)), + Part(root=FilePart(file=file_with_bytes)), + ] + + # Exercise + result = get_file_parts(parts) + + # Verify + assert result == [file_with_uri1, file_with_bytes] + + def test_get_file_parts_mixed_parts(self): + # Setup + file_with_uri = FileWithUri( + uri='file://path/to/file', mime_type='text/plain' + ) + parts = [ + Part(root=TextPart(text='some text')), + Part(root=FilePart(file=file_with_uri)), + ] + + # Exercise + result = get_file_parts(parts) + + # Verify + assert result == [file_with_uri] + + def test_get_file_parts_no_file_parts(self): + # Setup + parts = [ + Part(root=TextPart(text='some text')), + Part(root=DataPart(data={'key': 'value'})), + ] + + # Exercise + result = get_file_parts(parts) + + # Verify + assert result == [] + + def test_get_file_parts_empty_list(self): + # Setup + parts = [] + + # Exercise + result = get_file_parts(parts) + + # Verify + assert result == []