Skip to content

Commit 019ded4

Browse files
committed
Add tests for BaseClient and ClientTaskManager
1 parent 8a81b93 commit 019ded4

File tree

3 files changed

+282
-1
lines changed

3 files changed

+282
-1
lines changed

src/a2a/client/client_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
try:
2424
from a2a.client.transports.grpc import GrpcTransport
2525
except ImportError:
26-
GrpcTransport = None # type: ignore
26+
GrpcTransport = None # type: ignore # pyright: ignore
2727

2828

2929
logger = logging.getLogger(__name__)

tests/client/test_base_client.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from unittest.mock import AsyncMock, MagicMock, patch
2+
3+
import pytest
4+
5+
from a2a.client.base_client import BaseClient
6+
from a2a.client.client import ClientConfig
7+
from a2a.client.transports.base import ClientTransport
8+
from a2a.types import (
9+
AgentCapabilities,
10+
AgentCard,
11+
Message,
12+
Part,
13+
Role,
14+
Task,
15+
TaskState,
16+
TaskStatus,
17+
TextPart,
18+
)
19+
20+
21+
@pytest.fixture
22+
def mock_transport():
23+
transport = AsyncMock(spec=ClientTransport)
24+
return transport
25+
26+
27+
@pytest.fixture
28+
def sample_agent_card():
29+
return AgentCard(
30+
name='Test Agent',
31+
description='An agent for testing',
32+
url='http://test.com',
33+
version='1.0',
34+
capabilities=AgentCapabilities(streaming=True),
35+
default_input_modes=['text/plain'],
36+
default_output_modes=['text/plain'],
37+
skills=[],
38+
)
39+
40+
41+
@pytest.fixture
42+
def sample_message():
43+
return Message(
44+
role=Role.user,
45+
message_id='msg-1',
46+
parts=[Part(root=TextPart(text='Hello'))],
47+
)
48+
49+
50+
@pytest.fixture
51+
def base_client(sample_agent_card, mock_transport):
52+
config = ClientConfig(streaming=True)
53+
return BaseClient(
54+
card=sample_agent_card,
55+
config=config,
56+
transport=mock_transport,
57+
consumers=[],
58+
middleware=[],
59+
)
60+
61+
62+
@pytest.mark.asyncio
63+
async def test_send_message_streaming(
64+
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
65+
):
66+
async def create_stream(*args, **kwargs):
67+
yield Task(
68+
id='task-123',
69+
context_id='ctx-456',
70+
status=TaskStatus(state=TaskState.completed),
71+
)
72+
73+
mock_transport.send_message_streaming.return_value = create_stream()
74+
75+
events = [event async for event in base_client.send_message(sample_message)]
76+
77+
mock_transport.send_message_streaming.assert_called_once()
78+
assert not mock_transport.send_message.called
79+
assert len(events) == 1
80+
assert events[0][0].id == 'task-123'
81+
82+
83+
@pytest.mark.asyncio
84+
async def test_send_message_non_streaming(
85+
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
86+
):
87+
base_client._config.streaming = False
88+
mock_transport.send_message.return_value = Task(
89+
id='task-456',
90+
context_id='ctx-789',
91+
status=TaskStatus(state=TaskState.completed),
92+
)
93+
94+
events = [event async for event in base_client.send_message(sample_message)]
95+
96+
mock_transport.send_message.assert_called_once()
97+
assert not mock_transport.send_message_streaming.called
98+
assert len(events) == 1
99+
assert events[0][0].id == 'task-456'
100+
101+
102+
@pytest.mark.asyncio
103+
async def test_send_message_non_streaming_agent_capability_false(
104+
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
105+
):
106+
base_client._card.capabilities.streaming = False
107+
mock_transport.send_message.return_value = Task(
108+
id='task-789',
109+
context_id='ctx-101',
110+
status=TaskStatus(state=TaskState.completed),
111+
)
112+
113+
events = [event async for event in base_client.send_message(sample_message)]
114+
115+
mock_transport.send_message.assert_called_once()
116+
assert not mock_transport.send_message_streaming.called
117+
assert len(events) == 1
118+
assert events[0][0].id == 'task-789'
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import pytest
2+
from unittest.mock import AsyncMock, Mock, patch
3+
from a2a.client.client_task_manager import ClientTaskManager
4+
from a2a.client.errors import A2AClientInvalidArgsError, A2AClientInvalidStateError
5+
from a2a.types import (
6+
Task,
7+
TaskStatus,
8+
TaskState,
9+
TaskStatusUpdateEvent,
10+
TaskArtifactUpdateEvent,
11+
Message,
12+
Role,
13+
Part,
14+
TextPart,
15+
Artifact,
16+
)
17+
18+
19+
@pytest.fixture
20+
def task_manager():
21+
return ClientTaskManager()
22+
23+
24+
@pytest.fixture
25+
def sample_task():
26+
return Task(
27+
id="task123",
28+
context_id="context456",
29+
status=TaskStatus(state=TaskState.working),
30+
history=[],
31+
artifacts=[],
32+
)
33+
34+
35+
@pytest.fixture
36+
def sample_message():
37+
return Message(
38+
message_id="msg1",
39+
role=Role.user,
40+
parts=[Part(root=TextPart(text="Hello"))],
41+
)
42+
43+
44+
def test_get_task_no_task_id_returns_none(task_manager: ClientTaskManager):
45+
assert task_manager.get_task() is None
46+
47+
48+
def test_get_task_or_raise_no_task_raises_error(task_manager: ClientTaskManager):
49+
with pytest.raises(A2AClientInvalidStateError, match="no current Task"):
50+
task_manager.get_task_or_raise()
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_save_task_event_with_task(
55+
task_manager: ClientTaskManager, sample_task: Task
56+
):
57+
await task_manager.save_task_event(sample_task)
58+
assert task_manager.get_task() == sample_task
59+
assert task_manager._task_id == sample_task.id
60+
assert task_manager._context_id == sample_task.context_id
61+
62+
63+
@pytest.mark.asyncio
64+
async def test_save_task_event_with_task_already_set_raises_error(
65+
task_manager: ClientTaskManager, sample_task: Task
66+
):
67+
await task_manager.save_task_event(sample_task)
68+
with pytest.raises(
69+
A2AClientInvalidArgsError,
70+
match="Task is already set, create new manager for new tasks.",
71+
):
72+
await task_manager.save_task_event(sample_task)
73+
74+
75+
@pytest.mark.asyncio
76+
async def test_save_task_event_with_status_update(
77+
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
78+
):
79+
await task_manager.save_task_event(sample_task)
80+
status_update = TaskStatusUpdateEvent(
81+
task_id=sample_task.id,
82+
context_id=sample_task.context_id,
83+
status=TaskStatus(state=TaskState.completed, message=sample_message),
84+
final=True,
85+
)
86+
updated_task = await task_manager.save_task_event(status_update)
87+
assert updated_task.status.state == TaskState.completed
88+
assert updated_task.history == [sample_message]
89+
90+
91+
@pytest.mark.asyncio
92+
async def test_save_task_event_with_artifact_update(
93+
task_manager: ClientTaskManager, sample_task: Task
94+
):
95+
await task_manager.save_task_event(sample_task)
96+
artifact = Artifact(
97+
artifact_id="art1", parts=[Part(root=TextPart(text="artifact content"))]
98+
)
99+
artifact_update = TaskArtifactUpdateEvent(
100+
task_id=sample_task.id,
101+
context_id=sample_task.context_id,
102+
artifact=artifact,
103+
)
104+
105+
with patch("a2a.client.client_task_manager.append_artifact_to_task") as mock_append:
106+
updated_task = await task_manager.save_task_event(artifact_update)
107+
mock_append.assert_called_once_with(updated_task, artifact_update)
108+
109+
110+
@pytest.mark.asyncio
111+
async def test_save_task_event_creates_task_if_not_exists(
112+
task_manager: ClientTaskManager,
113+
):
114+
status_update = TaskStatusUpdateEvent(
115+
task_id="new_task",
116+
context_id="new_context",
117+
status=TaskStatus(state=TaskState.working),
118+
final=False,
119+
)
120+
updated_task = await task_manager.save_task_event(status_update)
121+
assert updated_task is not None
122+
assert updated_task.id == "new_task"
123+
assert updated_task.status.state == TaskState.working
124+
125+
126+
@pytest.mark.asyncio
127+
async def test_process_with_task_event(task_manager: ClientTaskManager, sample_task: Task):
128+
with patch.object(
129+
task_manager, "save_task_event", new_callable=AsyncMock
130+
) as mock_save:
131+
await task_manager.process(sample_task)
132+
mock_save.assert_called_once_with(sample_task)
133+
134+
135+
@pytest.mark.asyncio
136+
async def test_process_with_non_task_event(task_manager: ClientTaskManager):
137+
with patch.object(
138+
task_manager, "save_task_event", new_callable=Mock
139+
) as mock_save:
140+
non_task_event = "not a task event"
141+
await task_manager.process(non_task_event)
142+
mock_save.assert_not_called()
143+
144+
145+
def test_update_with_message(
146+
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
147+
):
148+
updated_task = task_manager.update_with_message(sample_message, sample_task)
149+
assert updated_task.history == [sample_message]
150+
151+
152+
def test_update_with_message_moves_status_message(
153+
task_manager: ClientTaskManager, sample_task: Task, sample_message: Message
154+
):
155+
status_message = Message(
156+
message_id="status_msg",
157+
role=Role.agent,
158+
parts=[Part(root=TextPart(text="Status"))],
159+
)
160+
sample_task.status.message = status_message
161+
updated_task = task_manager.update_with_message(sample_message, sample_task)
162+
assert updated_task.history == [status_message, sample_message]
163+
assert updated_task.status.message is None

0 commit comments

Comments
 (0)