Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def completed_task(
Returns:
A `Task` object with status set to 'completed'.
"""
if not artifacts or not all(isinstance(a, Artifact) for a in artifacts):
raise ValueError(
'artifacts must be a non-empty list of Artifact objects'
)

if history is None:
history = []
return Task(
Expand Down
56 changes: 51 additions & 5 deletions tests/utils/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from unittest.mock import patch

from a2a.types import Message, Part, Role, TextPart
import pytest

from a2a.types import Artifact, Message, Part, Role, TextPart
from a2a.utils.task import completed_task, new_task


Expand Down Expand Up @@ -57,7 +59,12 @@ def test_new_task_initial_message_in_history(self):
def test_completed_task_status(self):
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
artifacts = [] # Artifacts should be of type Artifact
artifacts = [
Artifact(
artifactId='artifact_1',
parts=[Part(root=TextPart(text='some content'))],
)
]
task = completed_task(
task_id=task_id,
context_id=context_id,
Expand All @@ -69,7 +76,12 @@ def test_completed_task_status(self):
def test_completed_task_assigns_ids_and_artifacts(self):
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
artifacts = [] # Artifacts should be of type Artifact
artifacts = [
Artifact(
artifactId='artifact_1',
parts=[Part(root=TextPart(text='some content'))],
)
]
task = completed_task(
task_id=task_id,
context_id=context_id,
Expand All @@ -83,7 +95,12 @@ def test_completed_task_assigns_ids_and_artifacts(self):
def test_completed_task_empty_history_if_not_provided(self):
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
artifacts = [] # Artifacts should be of type Artifact
artifacts = [
Artifact(
artifactId='artifact_1',
parts=[Part(root=TextPart(text='some content'))],
)
]
task = completed_task(
task_id=task_id, context_id=context_id, artifacts=artifacts
)
Expand All @@ -92,7 +109,12 @@ def test_completed_task_empty_history_if_not_provided(self):
def test_completed_task_uses_provided_history(self):
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
artifacts = [] # Artifacts should be of type Artifact
artifacts = [
Artifact(
artifactId='artifact_1',
parts=[Part(root=TextPart(text='some content'))],
)
]
history = [
Message(
role=Role.user,
Expand Down Expand Up @@ -132,6 +154,30 @@ def test_new_task_invalid_message_none_role(self):
)
new_task(msg)

def test_completed_task_empty_artifacts(self):
with pytest.raises(
ValueError,
match='artifacts must be a non-empty list of Artifact objects',
):
completed_task(
task_id='task-123',
context_id='ctx-456',
artifacts=[],
history=[],
)

def test_completed_task_invalid_artifact_type(self):
with pytest.raises(
ValueError,
match='artifacts must be a non-empty list of Artifact objects',
):
completed_task(
task_id='task-123',
context_id='ctx-456',
artifacts=['not an artifact'],
history=[],
)


if __name__ == '__main__':
unittest.main()
Loading