diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index 6a91f2d4..9b8c82a9 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -55,6 +55,11 @@ def completed_task( Returns: A `Task` object with status set to 'completed'. """ + if not artifacts or not all(isinstance(a, Artifact) for a in artifacts): + raise ValueError( + 'artifacts must be a non-empty list of Artifact objects' + ) + if history is None: history = [] return Task( diff --git a/tests/utils/test_task.py b/tests/utils/test_task.py index 76119f00..0e391b74 100644 --- a/tests/utils/test_task.py +++ b/tests/utils/test_task.py @@ -3,7 +3,9 @@ from unittest.mock import patch -from a2a.types import Message, Part, Role, TextPart +import pytest + +from a2a.types import Artifact, Message, Part, Role, TextPart from a2a.utils.task import completed_task, new_task @@ -57,7 +59,12 @@ def test_new_task_initial_message_in_history(self): def test_completed_task_status(self): task_id = str(uuid.uuid4()) context_id = str(uuid.uuid4()) - artifacts = [] # Artifacts should be of type Artifact + artifacts = [ + Artifact( + artifactId='artifact_1', + parts=[Part(root=TextPart(text='some content'))], + ) + ] task = completed_task( task_id=task_id, context_id=context_id, @@ -69,7 +76,12 @@ def test_completed_task_status(self): def test_completed_task_assigns_ids_and_artifacts(self): task_id = str(uuid.uuid4()) context_id = str(uuid.uuid4()) - artifacts = [] # Artifacts should be of type Artifact + artifacts = [ + Artifact( + artifactId='artifact_1', + parts=[Part(root=TextPart(text='some content'))], + ) + ] task = completed_task( task_id=task_id, context_id=context_id, @@ -83,7 +95,12 @@ def test_completed_task_assigns_ids_and_artifacts(self): def test_completed_task_empty_history_if_not_provided(self): task_id = str(uuid.uuid4()) context_id = str(uuid.uuid4()) - artifacts = [] # Artifacts should be of type Artifact + artifacts = [ + Artifact( + artifactId='artifact_1', + parts=[Part(root=TextPart(text='some content'))], + ) + ] task = completed_task( task_id=task_id, context_id=context_id, artifacts=artifacts ) @@ -92,7 +109,12 @@ def test_completed_task_empty_history_if_not_provided(self): def test_completed_task_uses_provided_history(self): task_id = str(uuid.uuid4()) context_id = str(uuid.uuid4()) - artifacts = [] # Artifacts should be of type Artifact + artifacts = [ + Artifact( + artifactId='artifact_1', + parts=[Part(root=TextPart(text='some content'))], + ) + ] history = [ Message( role=Role.user, @@ -132,6 +154,30 @@ def test_new_task_invalid_message_none_role(self): ) new_task(msg) + def test_completed_task_empty_artifacts(self): + with pytest.raises( + ValueError, + match='artifacts must be a non-empty list of Artifact objects', + ): + completed_task( + task_id='task-123', + context_id='ctx-456', + artifacts=[], + history=[], + ) + + def test_completed_task_invalid_artifact_type(self): + with pytest.raises( + ValueError, + match='artifacts must be a non-empty list of Artifact objects', + ): + completed_task( + task_id='task-123', + context_id='ctx-456', + artifacts=['not an artifact'], + history=[], + ) + if __name__ == '__main__': unittest.main()