Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def completed_task(
Returns:
A `Task` object with status set to 'completed'.
"""
if not artifacts or not all(isinstance(a, Artifact) for a in artifacts):
raise ValueError("artifacts must be a non-empty list of Artifact objects")

if history is None:
history = []
return Task(
Expand Down
29 changes: 24 additions & 5 deletions tests/utils/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from unittest.mock import patch

from a2a.types import Message, Part, Role, TextPart
import pytest
from a2a.types import Message, Part, Role, TextPart, Artifact
from a2a.utils.task import completed_task, new_task


Expand Down Expand Up @@ -57,7 +58,7 @@ def test_new_task_initial_message_in_history(self):
def test_completed_task_status(self):
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
artifacts = [] # Artifacts should be of type Artifact
artifacts = [Artifact(artifactId="artifact_1", parts=[Part(root=TextPart(text="some content"))])]
task = completed_task(
task_id=task_id,
context_id=context_id,
Expand All @@ -69,7 +70,7 @@ def test_completed_task_status(self):
def test_completed_task_assigns_ids_and_artifacts(self):
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
artifacts = [] # Artifacts should be of type Artifact
artifacts = [Artifact(artifactId="artifact_1", parts=[Part(root=TextPart(text="some content"))])]
task = completed_task(
task_id=task_id,
context_id=context_id,
Expand All @@ -83,7 +84,7 @@ def test_completed_task_assigns_ids_and_artifacts(self):
def test_completed_task_empty_history_if_not_provided(self):
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
artifacts = [] # Artifacts should be of type Artifact
artifacts = [Artifact(artifactId="artifact_1", parts=[Part(root=TextPart(text="some content"))])]
task = completed_task(
task_id=task_id, context_id=context_id, artifacts=artifacts
)
Expand All @@ -92,7 +93,7 @@ def test_completed_task_empty_history_if_not_provided(self):
def test_completed_task_uses_provided_history(self):
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
artifacts = [] # Artifacts should be of type Artifact
artifacts = [Artifact(artifactId="artifact_1", parts=[Part(root=TextPart(text="some content"))])]
history = [
Message(
role=Role.user,
Expand Down Expand Up @@ -132,6 +133,24 @@ def test_new_task_invalid_message_none_role(self):
)
new_task(msg)

def test_completed_task_empty_artifacts(self):
with pytest.raises(ValueError, match="artifacts must be a non-empty list of Artifact objects"):
completed_task(
task_id="task-123",
context_id="ctx-456",
artifacts=[],
history=[]
)

def test_completed_task_invalid_artifact_type(self):
with pytest.raises(ValueError, match="artifacts must be a non-empty list of Artifact objects"):
completed_task(
task_id="task-123",
context_id="ctx-456",
artifacts=["not an artifact"],
history=[]
)


if __name__ == '__main__':
unittest.main()
Loading