diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index ce8d62ab..cf85bdd3 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -55,6 +55,6 @@ jobs: - name: Install dependencies run: uv sync --dev --extra all - name: Run tests and check coverage - run: uv run pytest --cov=a2a --cov-report term --cov-fail-under=88 + run: PYTHONPATH=. uv run pytest --cov=a2a --cov-report term --cov-fail-under=88 - name: Show coverage summary in log run: uv run coverage report diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/builders.py b/tests/builders.py new file mode 100644 index 00000000..c1317242 --- /dev/null +++ b/tests/builders.py @@ -0,0 +1,211 @@ +from dataclasses import dataclass, field +from typing import Any + +from a2a.types import ( + Artifact, + Message, + Part, + Role, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, +) + + +@dataclass +class TaskBuilder: + id: str = 'task-default' + context_id: str = 'context-default' + state: TaskState = TaskState.submitted + kind: str = 'task' + artifacts: list[Artifact] = field(default_factory=list) + history: list[Message] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + def with_id(self, task_id: str) -> 'TaskBuilder': + self.id = task_id + return self + + def with_context_id(self, context_id: str) -> 'TaskBuilder': + self.context_id = context_id + return self + + def with_state(self, state: TaskState) -> 'TaskBuilder': + self.state = state + return self + + def with_metadata(self, **kwargs) -> 'TaskBuilder': + self.metadata.update(kwargs) + return self + + def with_history(self, *messages: Message) -> 'TaskBuilder': + self.history.extend(messages) + return self + + def with_artifacts(self, *artifacts: Artifact) -> 'TaskBuilder': + self.artifacts.extend(artifacts) + return self + + def build(self) -> Task: + return Task( + id=self.id, + context_id=self.context_id, + status=TaskStatus(state=self.state), + kind=self.kind, + artifacts=self.artifacts if self.artifacts else None, + history=self.history if self.history else None, + metadata=self.metadata if self.metadata else None, + ) + + +@dataclass +class MessageBuilder: + role: Role = Role.user + text: str = 'default message' + message_id: str = 'msg-default' + task_id: str | None = None + context_id: str | None = None + + def as_agent(self) -> 'MessageBuilder': + self.role = Role.agent + return self + + def as_user(self) -> 'MessageBuilder': + self.role = Role.user + return self + + def with_text(self, text: str) -> 'MessageBuilder': + self.text = text + return self + + def with_id(self, message_id: str) -> 'MessageBuilder': + self.message_id = message_id + return self + + def with_task_id(self, task_id: str) -> 'MessageBuilder': + self.task_id = task_id + return self + + def with_context_id(self, context_id: str) -> 'MessageBuilder': + self.context_id = context_id + return self + + def build(self) -> Message: + return Message( + role=self.role, + parts=[Part(TextPart(text=self.text))], + message_id=self.message_id, + task_id=self.task_id, + context_id=self.context_id, + ) + + +@dataclass +class ArtifactBuilder: + artifact_id: str = 'artifact-default' + name: str = 'default artifact' + text: str = 'default content' + description: str | None = None + + def with_id(self, artifact_id: str) -> 'ArtifactBuilder': + self.artifact_id = artifact_id + return self + + def with_name(self, name: str) -> 'ArtifactBuilder': + self.name = name + return self + + def with_text(self, text: str) -> 'ArtifactBuilder': + self.text = text + return self + + def with_description(self, description: str) -> 'ArtifactBuilder': + self.description = description + return self + + def build(self) -> Artifact: + return Artifact( + artifact_id=self.artifact_id, + name=self.name, + parts=[Part(TextPart(text=self.text))], + description=self.description, + ) + + +@dataclass +class StatusUpdateEventBuilder: + task_id: str = 'task-default' + context_id: str = 'context-default' + state: TaskState = TaskState.working + message: Message | None = None + final: bool = False + metadata: dict[str, Any] = field(default_factory=dict) + + def for_task(self, task_id: str) -> 'StatusUpdateEventBuilder': + self.task_id = task_id + return self + + def with_state(self, state: TaskState) -> 'StatusUpdateEventBuilder': + self.state = state + return self + + def with_message(self, message: Message) -> 'StatusUpdateEventBuilder': + self.message = message + return self + + def as_final(self) -> 'StatusUpdateEventBuilder': + self.final = True + return self + + def with_metadata(self, **kwargs) -> 'StatusUpdateEventBuilder': + self.metadata.update(kwargs) + return self + + def build(self) -> TaskStatusUpdateEvent: + return TaskStatusUpdateEvent( + task_id=self.task_id, + context_id=self.context_id, + status=TaskStatus(state=self.state, message=self.message), + final=self.final, + metadata=self.metadata if self.metadata else None, + ) + + +@dataclass +class ArtifactUpdateEventBuilder: + task_id: str = 'task-default' + context_id: str = 'context-default' + artifact: Artifact | None = None + append: bool = False + last_chunk: bool = False + + def for_task(self, task_id: str) -> 'ArtifactUpdateEventBuilder': + self.task_id = task_id + return self + + def with_artifact(self, artifact: Artifact) -> 'ArtifactUpdateEventBuilder': + self.artifact = artifact + return self + + def as_append(self) -> 'ArtifactUpdateEventBuilder': + self.append = True + return self + + def as_last_chunk(self) -> 'ArtifactUpdateEventBuilder': + self.last_chunk = True + return self + + def build(self) -> TaskArtifactUpdateEvent: + artifact = self.artifact + if not artifact: + artifact = ArtifactBuilder().build() + return TaskArtifactUpdateEvent( + task_id=self.task_id, + context_id=self.context_id, + artifact=artifact, + append=self.append, + last_chunk=self.last_chunk, + ) diff --git a/tests/fixtures.py b/tests/fixtures.py new file mode 100644 index 00000000..19d75f7b --- /dev/null +++ b/tests/fixtures.py @@ -0,0 +1,118 @@ +import pytest + +from a2a.server.tasks import TaskManager +from a2a.types import TaskState +from tests.builders import ( + ArtifactBuilder, + MessageBuilder, + TaskBuilder, +) +from tests.test_doubles import ( + FakeHttpClient, + InMemoryTaskStore, + SpyEventQueue, + StubPushNotificationConfigStore, +) + + +@pytest.fixture +def task_store(): + return InMemoryTaskStore() + + +@pytest.fixture +def event_queue(): + return SpyEventQueue() + + +@pytest.fixture +def push_config_store(): + return StubPushNotificationConfigStore() + + +@pytest.fixture +def http_client(): + return FakeHttpClient() + + +@pytest.fixture +def task_builder(): + return TaskBuilder() + + +@pytest.fixture +def message_builder(): + return MessageBuilder() + + +@pytest.fixture +def artifact_builder(): + return ArtifactBuilder() + + +@pytest.fixture +def submitted_task(task_builder): + return task_builder.with_state(TaskState.submitted).build() + + +@pytest.fixture +def working_task(task_builder): + return task_builder.with_state(TaskState.working).build() + + +@pytest.fixture +def completed_task(task_builder): + return task_builder.with_state(TaskState.completed).build() + + +@pytest.fixture +def task_with_history(task_builder): + messages = [ + MessageBuilder().as_user().with_text('Hello').build(), + MessageBuilder().as_agent().with_text('Hi there!').build(), + ] + return task_builder.with_history(*messages).build() + + +@pytest.fixture +def task_with_artifacts(task_builder): + artifacts = [ + ArtifactBuilder().with_id('art1').with_name('file.txt').build(), + ArtifactBuilder().with_id('art2').with_name('data.json').build(), + ] + return task_builder.with_artifacts(*artifacts).build() + + +@pytest.fixture +def task_manager(task_store): + return TaskManager( + task_id='task-123', + context_id='context-456', + task_store=task_store, + initial_message=None, + ) + + +@pytest.fixture +def task_manager_factory(task_store): + def factory(task_id=None, context_id=None, initial_message=None): + return TaskManager( + task_id=task_id, + context_id=context_id, + task_store=task_store, + initial_message=initial_message, + ) + + return factory + + +@pytest.fixture +def populated_task_store(task_store): + tasks = [ + TaskBuilder().with_id('task-1').with_state(TaskState.submitted).build(), + TaskBuilder().with_id('task-2').with_state(TaskState.working).build(), + TaskBuilder().with_id('task-3').with_state(TaskState.completed).build(), + ] + for task in tasks: + task_store.set_task(task) + return task_store diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 88d4d3d1..323ff00f 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,4 +1,5 @@ import asyncio + from collections.abc import AsyncGenerator from typing import NamedTuple from unittest.mock import ANY, AsyncMock @@ -7,6 +8,7 @@ import httpx import pytest import pytest_asyncio + from grpc.aio import Channel from a2a.client.transports import JsonRpcTransport, RestTransport @@ -36,6 +38,7 @@ TransportProtocol, ) + # --- Test Constants --- TASK_FROM_STREAM = Task( diff --git a/tests/test_doubles.py b/tests/test_doubles.py new file mode 100644 index 00000000..d52939ba --- /dev/null +++ b/tests/test_doubles.py @@ -0,0 +1,219 @@ +from collections import defaultdict +from typing import Any + +from a2a.server.events.event_queue import Event, EventQueue +from a2a.server.tasks import TaskStore +from a2a.server.tasks.push_notification_config_store import ( + PushNotificationConfigStore, +) +from a2a.types import PushNotificationConfig, Task + + +class InMemoryTaskStore(TaskStore): + def __init__(self): + self._tasks: dict[str, Task] = {} + self._save_count = 0 + self._get_count = 0 + self._delete_count = 0 + + async def save(self, task: Task, context: Any = None) -> None: + self._save_count += 1 + self._tasks[task.id] = task + + async def get(self, task_id: str, context: Any = None) -> Task | None: + self._get_count += 1 + return self._tasks.get(task_id) + + async def delete(self, task_id: str, context: Any = None) -> None: + self._delete_count += 1 + self._tasks.pop(task_id, None) + + def assert_saved(self, task_id: str) -> None: + assert task_id in self._tasks, f'Task {task_id} was not saved' + + def assert_not_saved(self, task_id: str) -> None: + assert task_id not in self._tasks, f'Task {task_id} should not be saved' + + def assert_save_called(self, times: int = 1) -> None: + assert self._save_count == times, ( + f'Expected save to be called {times} times, but was called {self._save_count} times' + ) + + def assert_get_called(self, times: int = 1) -> None: + assert self._get_count == times, ( + f'Expected get to be called {times} times, but was called {self._get_count} times' + ) + + def assert_delete_called(self, times: int = 1) -> None: + assert self._delete_count == times, ( + f'Expected delete to be called {times} times, but was called {self._delete_count} times' + ) + + def get_saved_task(self, task_id: str) -> Task: + assert task_id in self._tasks, f'Task {task_id} not found' + return self._tasks[task_id] + + def set_task(self, task: Task) -> None: + self._tasks[task.id] = task + + def clear(self) -> None: + self._tasks.clear() + self._save_count = 0 + self._get_count = 0 + self._delete_count = 0 + + +class SpyEventQueue(EventQueue): + def __init__(self): + self.events: list[Event] = [] + self._closed = False + + async def publish(self, event: Event) -> None: + if self._closed: + raise RuntimeError('Cannot publish to closed queue') + self.events.append(event) + + async def close(self) -> None: + self._closed = True + + def is_closed(self) -> bool: + return self._closed + + def assert_event_published(self, event_type: type) -> None: + assert any(isinstance(e, event_type) for e in self.events), ( + f'No event of type {event_type.__name__} was published' + ) + + def assert_no_event_published(self, event_type: type) -> None: + assert not any(isinstance(e, event_type) for e in self.events), ( + f'Event of type {event_type.__name__} should not have been published' + ) + + def assert_event_count(self, count: int) -> None: + assert len(self.events) == count, ( + f'Expected {count} events, but got {len(self.events)}' + ) + + def get_events_of_type(self, event_type: type) -> list[Event]: + return [e for e in self.events if isinstance(e, event_type)] + + def get_last_event(self) -> Event | None: + return self.events[-1] if self.events else None + + def clear(self) -> None: + self.events.clear() + self._closed = False + + +class StubPushNotificationConfigStore(PushNotificationConfigStore): + def __init__(self): + self._configs: dict[str, list[PushNotificationConfig]] = defaultdict( + list + ) + self._set_count = 0 + self._get_count = 0 + self._delete_count = 0 + + async def set_info( + self, task_id: str, config: PushNotificationConfig + ) -> None: + self._set_count += 1 + configs = self._configs[task_id] + if config.id: + configs = [c for c in configs if c.id != config.id] + configs.append(config) + self._configs[task_id] = configs + + async def get_info(self, task_id: str) -> list[PushNotificationConfig]: + self._get_count += 1 + return self._configs.get(task_id, []) + + async def delete_info( + self, task_id: str, config_id: str | None = None + ) -> None: + self._delete_count += 1 + if config_id: + self._configs[task_id] = [ + c for c in self._configs.get(task_id, []) if c.id != config_id + ] + else: + self._configs.pop(task_id, None) + + def assert_config_set(self, task_id: str) -> None: + assert task_id in self._configs, f'No config set for task {task_id}' + + def assert_set_called(self, times: int = 1) -> None: + assert self._set_count == times, ( + f'Expected set_info to be called {times} times, but was called {self._set_count} times' + ) + + def get_config(self, task_id: str) -> PushNotificationConfig | None: + configs = self._configs.get(task_id, []) + return configs[0] if configs else None + + def clear(self) -> None: + self._configs.clear() + self._set_count = 0 + self._get_count = 0 + self._delete_count = 0 + + +class FakeHttpClient: + def __init__(self): + self.requests: list[dict[str, Any]] = [] + self.responses: list[dict[str, Any]] = [] + self._response_index = 0 + + def add_response( + self, + status: int, + json: dict | None = None, + text: str | None = None, + ): + self.responses.append({'status': status, 'json': json, 'text': text}) + + async def post(self, url: str, **kwargs): + self.requests.append({'method': 'POST', 'url': url, **kwargs}) + + if self._response_index < len(self.responses): + response = self.responses[self._response_index] + self._response_index += 1 + return FakeResponse( + response['status'], response.get('json'), response.get('text') + ) + + return FakeResponse(200, {}) + + def assert_request_made(self, url: str, method: str = 'POST') -> None: + assert any( + r['url'] == url and r.get('method', 'POST') == method + for r in self.requests + ), f'No {method} request made to {url}' + + def get_last_request(self) -> dict[str, Any] | None: + return self.requests[-1] if self.requests else None + + +class FakeResponse: + def __init__( + self, + status_code: int, + json_data: dict | None = None, + text_data: str | None = None, + ): + self.status_code = status_code + self._json = json_data + self._text = text_data or '' + + def json(self): + if self._json is None: + raise ValueError('No JSON data') + return self._json + + @property + def text(self): + return self._text + + def raise_for_status(self): + if self.status_code >= 400: + raise Exception(f'HTTP {self.status_code}')