Skip to content

Commit d907865

Browse files
authored
Merge branch 'main' into test.task-updater
2 parents 2a477d9 + 0fdd7ae commit d907865

File tree

2 files changed

+435
-0
lines changed

2 files changed

+435
-0
lines changed
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import uuid
2+
3+
from unittest.mock import Mock, patch
4+
5+
import pytest
6+
7+
from a2a.server.agent_execution import RequestContext
8+
from a2a.types import (
9+
Message,
10+
MessageSendParams,
11+
Task,
12+
)
13+
14+
15+
class TestRequestContext:
16+
"""Tests for the RequestContext class."""
17+
18+
@pytest.fixture
19+
def mock_message(self):
20+
"""Fixture for a mock Message."""
21+
return Mock(spec=Message, taskId=None, contextId=None)
22+
23+
@pytest.fixture
24+
def mock_params(self, mock_message):
25+
"""Fixture for a mock MessageSendParams."""
26+
return Mock(spec=MessageSendParams, message=mock_message)
27+
28+
@pytest.fixture
29+
def mock_task(self):
30+
"""Fixture for a mock Task."""
31+
return Mock(spec=Task, id='task-123', contextId='context-456')
32+
33+
def test_init_without_params(self):
34+
"""Test initialization without parameters."""
35+
context = RequestContext()
36+
assert context.message is None
37+
assert context.task_id is None
38+
assert context.context_id is None
39+
assert context.current_task is None
40+
assert context.related_tasks == []
41+
42+
def test_init_with_params_no_ids(self, mock_params):
43+
"""Test initialization with params but no task or context IDs."""
44+
with patch(
45+
'uuid.uuid4',
46+
side_effect=[
47+
uuid.UUID('00000000-0000-0000-0000-000000000001'),
48+
uuid.UUID('00000000-0000-0000-0000-000000000002'),
49+
],
50+
):
51+
context = RequestContext(request=mock_params)
52+
53+
assert context.message == mock_params.message
54+
assert context.task_id == '00000000-0000-0000-0000-000000000001'
55+
assert (
56+
mock_params.message.taskId == '00000000-0000-0000-0000-000000000001'
57+
)
58+
assert context.context_id == '00000000-0000-0000-0000-000000000002'
59+
assert (
60+
mock_params.message.contextId
61+
== '00000000-0000-0000-0000-000000000002'
62+
)
63+
64+
def test_init_with_task_id(self, mock_params):
65+
"""Test initialization with task ID provided."""
66+
task_id = 'task-123'
67+
context = RequestContext(request=mock_params, task_id=task_id)
68+
69+
assert context.task_id == task_id
70+
assert mock_params.message.taskId == task_id
71+
72+
def test_init_with_context_id(self, mock_params):
73+
"""Test initialization with context ID provided."""
74+
context_id = 'context-456'
75+
context = RequestContext(request=mock_params, context_id=context_id)
76+
77+
assert context.context_id == context_id
78+
assert mock_params.message.contextId == context_id
79+
80+
def test_init_with_both_ids(self, mock_params):
81+
"""Test initialization with both task and context IDs provided."""
82+
task_id = 'task-123'
83+
context_id = 'context-456'
84+
context = RequestContext(
85+
request=mock_params, task_id=task_id, context_id=context_id
86+
)
87+
88+
assert context.task_id == task_id
89+
assert mock_params.message.taskId == task_id
90+
assert context.context_id == context_id
91+
assert mock_params.message.contextId == context_id
92+
93+
def test_init_with_task(self, mock_params, mock_task):
94+
"""Test initialization with a task object."""
95+
context = RequestContext(request=mock_params, task=mock_task)
96+
97+
assert context.current_task == mock_task
98+
99+
def test_get_user_input_no_params(self):
100+
"""Test get_user_input with no params returns empty string."""
101+
context = RequestContext()
102+
assert context.get_user_input() == ''
103+
104+
def test_attach_related_task(self, mock_task):
105+
"""Test attach_related_task adds a task to related_tasks."""
106+
context = RequestContext()
107+
assert len(context.related_tasks) == 0
108+
109+
context.attach_related_task(mock_task)
110+
assert len(context.related_tasks) == 1
111+
assert context.related_tasks[0] == mock_task
112+
113+
# Test adding multiple tasks
114+
another_task = Mock(spec=Task)
115+
context.attach_related_task(another_task)
116+
assert len(context.related_tasks) == 2
117+
assert context.related_tasks[1] == another_task
118+
119+
def test_current_task_property(self, mock_task):
120+
"""Test current_task getter and setter."""
121+
context = RequestContext()
122+
assert context.current_task is None
123+
124+
context.current_task = mock_task
125+
assert context.current_task == mock_task
126+
127+
# Change current task
128+
new_task = Mock(spec=Task)
129+
context.current_task = new_task
130+
assert context.current_task == new_task
131+
132+
def test_check_or_generate_task_id_no_params(self):
133+
"""Test _check_or_generate_task_id with no params does nothing."""
134+
context = RequestContext()
135+
context._check_or_generate_task_id()
136+
assert context.task_id is None
137+
138+
def test_check_or_generate_task_id_with_existing_task_id(self, mock_params):
139+
"""Test _check_or_generate_task_id with existing task ID."""
140+
existing_id = 'existing-task-id'
141+
mock_params.message.taskId = existing_id
142+
143+
context = RequestContext(request=mock_params)
144+
# The method is called during initialization
145+
146+
assert context.task_id == existing_id
147+
assert mock_params.message.taskId == existing_id
148+
149+
def test_check_or_generate_context_id_no_params(self):
150+
"""Test _check_or_generate_context_id with no params does nothing."""
151+
context = RequestContext()
152+
context._check_or_generate_context_id()
153+
assert context.context_id is None
154+
155+
def test_check_or_generate_context_id_with_existing_context_id(
156+
self, mock_params
157+
):
158+
"""Test _check_or_generate_context_id with existing context ID."""
159+
existing_id = 'existing-context-id'
160+
mock_params.message.contextId = existing_id
161+
162+
context = RequestContext(request=mock_params)
163+
# The method is called during initialization
164+
165+
assert context.context_id == existing_id
166+
assert mock_params.message.contextId == existing_id
167+
168+
def test_with_related_tasks_provided(self, mock_task):
169+
"""Test initialization with related tasks provided."""
170+
related_tasks = [mock_task, Mock(spec=Task)]
171+
context = RequestContext(related_tasks=related_tasks)
172+
173+
assert context.related_tasks == related_tasks
174+
assert len(context.related_tasks) == 2
175+
176+
def test_message_property_without_params(self):
177+
"""Test message property returns None when no params are provided."""
178+
context = RequestContext()
179+
assert context.message is None
180+
181+
def test_message_property_with_params(self, mock_params):
182+
"""Test message property returns the message from params."""
183+
context = RequestContext(request=mock_params)
184+
assert context.message == mock_params.message
185+
186+
def test_init_with_existing_ids_in_message(self, mock_message, mock_params):
187+
"""Test initialization with existing IDs in the message."""
188+
mock_message.taskId = 'existing-task-id'
189+
mock_message.contextId = 'existing-context-id'
190+
191+
context = RequestContext(request=mock_params)
192+
193+
assert context.task_id == 'existing-task-id'
194+
assert context.context_id == 'existing-context-id'
195+
# No new UUIDs should be generated
196+
197+
def test_init_with_task_id_and_existing_task_id_match(
198+
self, mock_params, mock_task
199+
):
200+
"""Test initialization succeeds when task_id matches task.id."""
201+
mock_params.message.taskId = mock_task.id
202+
203+
context = RequestContext(
204+
request=mock_params, task_id=mock_task.id, task=mock_task
205+
)
206+
207+
assert context.task_id == mock_task.id
208+
assert context.current_task == mock_task
209+
210+
def test_init_with_context_id_and_existing_context_id_match(
211+
self, mock_params, mock_task
212+
):
213+
"""Test initialization succeeds when context_id matches task.contextId."""
214+
mock_params.message.taskId = mock_task.id # Set matching task ID
215+
mock_params.message.contextId = mock_task.contextId
216+
217+
context = RequestContext(
218+
request=mock_params,
219+
task_id=mock_task.id,
220+
context_id=mock_task.contextId,
221+
task=mock_task,
222+
)
223+
224+
assert context.context_id == mock_task.contextId
225+
assert context.current_task == mock_task

0 commit comments

Comments
 (0)