Skip to content

Commit 82e1d3a

Browse files
committed
adding 19 tests for agent_execution/context.py
1 parent 224d4f5 commit 82e1d3a

File tree

1 file changed

+219
-0
lines changed

1 file changed

+219
-0
lines changed
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import pytest
2+
import uuid
3+
from unittest.mock import Mock, patch
4+
5+
from a2a.server.agent_execution import RequestContext
6+
from a2a.types import (
7+
InvalidParamsError,
8+
Message,
9+
MessageSendParams,
10+
Task,
11+
)
12+
from a2a.utils.errors import ServerError
13+
14+
15+
16+
class TestRequestContext:
17+
"""Tests for the RequestContext class."""
18+
19+
@pytest.fixture
20+
def mock_message(self):
21+
"""Fixture for a mock Message."""
22+
return Mock(spec=Message, taskId=None, contextId=None)
23+
24+
@pytest.fixture
25+
def mock_params(self, mock_message):
26+
"""Fixture for a mock MessageSendParams."""
27+
return Mock(spec=MessageSendParams, message=mock_message)
28+
29+
@pytest.fixture
30+
def mock_task(self):
31+
"""Fixture for a mock Task."""
32+
return Mock(spec=Task, id="task-123", contextId="context-456")
33+
34+
def test_init_without_params(self):
35+
"""Test initialization without parameters."""
36+
context = RequestContext()
37+
assert context.message is None
38+
assert context.task_id is None
39+
assert context.context_id is None
40+
assert context.current_task is None
41+
assert context.related_tasks == []
42+
43+
def test_init_with_params_no_ids(self, mock_params):
44+
"""Test initialization with params but no task or context IDs."""
45+
with patch('uuid.uuid4', side_effect=[
46+
uuid.UUID('00000000-0000-0000-0000-000000000001'),
47+
uuid.UUID('00000000-0000-0000-0000-000000000002')
48+
]):
49+
context = RequestContext(request=mock_params)
50+
51+
assert context.message == mock_params.message
52+
assert context.task_id == '00000000-0000-0000-0000-000000000001'
53+
assert mock_params.message.taskId == '00000000-0000-0000-0000-000000000001'
54+
assert context.context_id == '00000000-0000-0000-0000-000000000002'
55+
assert mock_params.message.contextId == '00000000-0000-0000-0000-000000000002'
56+
57+
def test_init_with_task_id(self, mock_params):
58+
"""Test initialization with task ID provided."""
59+
task_id = "task-123"
60+
context = RequestContext(request=mock_params, task_id=task_id)
61+
62+
assert context.task_id == task_id
63+
assert mock_params.message.taskId == task_id
64+
65+
def test_init_with_context_id(self, mock_params):
66+
"""Test initialization with context ID provided."""
67+
context_id = "context-456"
68+
context = RequestContext(request=mock_params, context_id=context_id)
69+
70+
assert context.context_id == context_id
71+
assert mock_params.message.contextId == context_id
72+
73+
def test_init_with_both_ids(self, mock_params):
74+
"""Test initialization with both task and context IDs provided."""
75+
task_id = "task-123"
76+
context_id = "context-456"
77+
context = RequestContext(
78+
request=mock_params,
79+
task_id=task_id,
80+
context_id=context_id
81+
)
82+
83+
assert context.task_id == task_id
84+
assert mock_params.message.taskId == task_id
85+
assert context.context_id == context_id
86+
assert mock_params.message.contextId == context_id
87+
88+
def test_init_with_task(self, mock_params, mock_task):
89+
"""Test initialization with a task object."""
90+
context = RequestContext(
91+
request=mock_params,
92+
task=mock_task
93+
)
94+
95+
assert context.current_task == mock_task
96+
97+
def test_get_user_input_no_params(self):
98+
"""Test get_user_input with no params returns empty string."""
99+
context = RequestContext()
100+
assert context.get_user_input() == ''
101+
102+
def test_attach_related_task(self, mock_task):
103+
"""Test attach_related_task adds a task to related_tasks."""
104+
context = RequestContext()
105+
assert len(context.related_tasks) == 0
106+
107+
context.attach_related_task(mock_task)
108+
assert len(context.related_tasks) == 1
109+
assert context.related_tasks[0] == mock_task
110+
111+
# Test adding multiple tasks
112+
another_task = Mock(spec=Task)
113+
context.attach_related_task(another_task)
114+
assert len(context.related_tasks) == 2
115+
assert context.related_tasks[1] == another_task
116+
117+
def test_current_task_property(self, mock_task):
118+
"""Test current_task getter and setter."""
119+
context = RequestContext()
120+
assert context.current_task is None
121+
122+
context.current_task = mock_task
123+
assert context.current_task == mock_task
124+
125+
# Change current task
126+
new_task = Mock(spec=Task)
127+
context.current_task = new_task
128+
assert context.current_task == new_task
129+
130+
def test_check_or_generate_task_id_no_params(self):
131+
"""Test _check_or_generate_task_id with no params does nothing."""
132+
context = RequestContext()
133+
context._check_or_generate_task_id()
134+
assert context.task_id is None
135+
136+
def test_check_or_generate_task_id_with_existing_task_id(self, mock_params):
137+
"""Test _check_or_generate_task_id with existing task ID."""
138+
existing_id = "existing-task-id"
139+
mock_params.message.taskId = existing_id
140+
141+
context = RequestContext(request=mock_params)
142+
# The method is called during initialization
143+
144+
assert context.task_id == existing_id
145+
assert mock_params.message.taskId == existing_id
146+
147+
def test_check_or_generate_context_id_no_params(self):
148+
"""Test _check_or_generate_context_id with no params does nothing."""
149+
context = RequestContext()
150+
context._check_or_generate_context_id()
151+
assert context.context_id is None
152+
153+
def test_check_or_generate_context_id_with_existing_context_id(self, mock_params):
154+
"""Test _check_or_generate_context_id with existing context ID."""
155+
existing_id = "existing-context-id"
156+
mock_params.message.contextId = existing_id
157+
158+
context = RequestContext(request=mock_params)
159+
# The method is called during initialization
160+
161+
assert context.context_id == existing_id
162+
assert mock_params.message.contextId == existing_id
163+
164+
def test_with_related_tasks_provided(self, mock_task):
165+
"""Test initialization with related tasks provided."""
166+
related_tasks = [mock_task, Mock(spec=Task)]
167+
context = RequestContext(related_tasks=related_tasks)
168+
169+
assert context.related_tasks == related_tasks
170+
assert len(context.related_tasks) == 2
171+
172+
def test_message_property_without_params(self):
173+
"""Test message property returns None when no params are provided."""
174+
context = RequestContext()
175+
assert context.message is None
176+
177+
def test_message_property_with_params(self, mock_params):
178+
"""Test message property returns the message from params."""
179+
context = RequestContext(request=mock_params)
180+
assert context.message == mock_params.message
181+
182+
def test_init_with_existing_ids_in_message(self, mock_message, mock_params):
183+
"""Test initialization with existing IDs in the message."""
184+
mock_message.taskId = "existing-task-id"
185+
mock_message.contextId = "existing-context-id"
186+
187+
context = RequestContext(request=mock_params)
188+
189+
assert context.task_id == "existing-task-id"
190+
assert context.context_id == "existing-context-id"
191+
# No new UUIDs should be generated
192+
193+
def test_init_with_task_id_and_existing_task_id_match(self, mock_params, mock_task):
194+
"""Test initialization succeeds when task_id matches task.id."""
195+
mock_params.message.taskId = mock_task.id
196+
197+
context = RequestContext(
198+
request=mock_params,
199+
task_id=mock_task.id,
200+
task=mock_task
201+
)
202+
203+
assert context.task_id == mock_task.id
204+
assert context.current_task == mock_task
205+
206+
def test_init_with_context_id_and_existing_context_id_match(self, mock_params, mock_task):
207+
"""Test initialization succeeds when context_id matches task.contextId."""
208+
mock_params.message.taskId = mock_task.id # Set matching task ID
209+
mock_params.message.contextId = mock_task.contextId
210+
211+
context = RequestContext(
212+
request=mock_params,
213+
task_id=mock_task.id,
214+
context_id=mock_task.contextId,
215+
task=mock_task
216+
)
217+
218+
assert context.context_id == mock_task.contextId
219+
assert context.current_task == mock_task

0 commit comments

Comments
 (0)