11import unittest
2- from unittest .mock import AsyncMock , MagicMock
32
3+ from unittest .mock import AsyncMock
4+
5+ from a2a .auth .user import UnauthenticatedUser # Import User types
6+ from a2a .server .agent_execution .context import (
7+ RequestContext , # Corrected import path
8+ )
49from a2a .server .agent_execution .simple_request_context_builder import (
510 SimpleRequestContextBuilder ,
611)
7- from a2a .server .agent_execution . context import RequestContext # Corrected import path
12+ from a2a .server .context import ServerCallContext
813from a2a .server .tasks .task_store import TaskStore
914from a2a .types import (
1015 Message ,
1116 MessageSendParams ,
12- Task ,
17+ Part ,
1318 # ServerCallContext, # Removed from a2a.types
1419 Role ,
15- Part ,
16- TextPart ,
17- TaskStatus ,
20+ Task ,
1821 TaskState ,
22+ TaskStatus ,
23+ TextPart ,
1924)
20- from a2a .server .context import ServerCallContext
21- from a2a .auth .user import User , UnauthenticatedUser # Import User types
2225
2326
2427# Helper to create a simple message
2528def create_sample_message (
26- content = "test message" , msg_id = "msg1" , role = Role .user , reference_task_ids = None
29+ content = 'test message' ,
30+ msg_id = 'msg1' ,
31+ role = Role .user ,
32+ reference_task_ids = None ,
2733):
2834 return Message (
2935 messageId = msg_id ,
@@ -32,14 +38,18 @@ def create_sample_message(
3238 referenceTaskIds = reference_task_ids if reference_task_ids else [],
3339 )
3440
41+
3542# Helper to create a simple task
36- def create_sample_task (task_id = "task1" , status_state = TaskState .submitted , context_id = "ctx1" ):
43+ def create_sample_task (
44+ task_id = 'task1' , status_state = TaskState .submitted , context_id = 'ctx1'
45+ ):
3746 return Task (
3847 id = task_id ,
3948 contextId = context_id ,
4049 status = TaskStatus (state = status_state ),
4150 )
4251
52+
4353class TestSimpleRequestContextBuilder (unittest .IsolatedAsyncioTestCase ):
4454 def setUp (self ):
4555 self .mock_task_store = AsyncMock (spec = TaskStore )
@@ -61,22 +71,28 @@ def test_init_with_populate_false_task_store_none(self):
6171 def test_init_with_populate_false_task_store_provided (self ):
6272 # Even if populate is false, task_store might still be provided (though not used by build for related_tasks)
6373 builder = SimpleRequestContextBuilder (
64- should_populate_referred_tasks = False , task_store = self .mock_task_store
74+ should_populate_referred_tasks = False ,
75+ task_store = self .mock_task_store ,
6576 )
6677 self .assertFalse (builder ._should_populate_referred_tasks )
6778 self .assertEqual (builder ._task_store , self .mock_task_store )
6879
6980 async def test_build_basic_context_no_populate (self ):
7081 builder = SimpleRequestContextBuilder (
71- should_populate_referred_tasks = False , task_store = self .mock_task_store
82+ should_populate_referred_tasks = False ,
83+ task_store = self .mock_task_store ,
7284 )
7385
7486 params = MessageSendParams (message = create_sample_message ())
75- task_id = "test_task_id_1"
76- context_id = "test_context_id_1"
77- current_task = create_sample_task (task_id = task_id , context_id = context_id )
87+ task_id = 'test_task_id_1'
88+ context_id = 'test_context_id_1'
89+ current_task = create_sample_task (
90+ task_id = task_id , context_id = context_id
91+ )
7892 # Pass a valid User instance, e.g., UnauthenticatedUser or a mock spec'd as User
79- server_call_context = ServerCallContext (user = UnauthenticatedUser (), auth_token = "dummy_token" )
93+ server_call_context = ServerCallContext (
94+ user = UnauthenticatedUser (), auth_token = 'dummy_token'
95+ )
8096
8197 request_context = await builder .build (
8298 params = params ,
@@ -92,18 +108,22 @@ async def test_build_basic_context_no_populate(self):
92108 self .assertEqual (request_context .configuration , params .configuration )
93109 self .assertEqual (request_context .task_id , task_id )
94110 self .assertEqual (request_context .context_id , context_id )
95- self .assertEqual (request_context .current_task , current_task ) # Property is current_task
96- self .assertEqual (request_context .call_context , server_call_context ) # Property is call_context
97- self .assertEqual (request_context .related_tasks , []) # Initialized to []
111+ self .assertEqual (
112+ request_context .current_task , current_task
113+ ) # Property is current_task
114+ self .assertEqual (
115+ request_context .call_context , server_call_context
116+ ) # Property is call_context
117+ self .assertEqual (request_context .related_tasks , []) # Initialized to []
98118 self .mock_task_store .get .assert_not_called ()
99119
100120 async def test_build_populate_true_with_reference_task_ids (self ):
101121 builder = SimpleRequestContextBuilder (
102122 should_populate_referred_tasks = True , task_store = self .mock_task_store
103123 )
104- ref_task_id1 = " ref_task1"
105- ref_task_id2 = " ref_task2_missing"
106- ref_task_id3 = " ref_task3"
124+ ref_task_id1 = ' ref_task1'
125+ ref_task_id2 = ' ref_task2_missing'
126+ ref_task_id3 = ' ref_task3'
107127
108128 mock_ref_task1 = create_sample_task (task_id = ref_task_id1 )
109129 mock_ref_task3 = create_sample_task (task_id = ref_task_id3 )
@@ -117,15 +137,22 @@ async def get_side_effect(task_id):
117137 if task_id == ref_task_id3 :
118138 return mock_ref_task3
119139 return None
140+
120141 self .mock_task_store .get = AsyncMock (side_effect = get_side_effect )
121142
122143 params = MessageSendParams (
123- message = create_sample_message (reference_task_ids = [ref_task_id1 , ref_task_id2 , ref_task_id3 ])
144+ message = create_sample_message (
145+ reference_task_ids = [ref_task_id1 , ref_task_id2 , ref_task_id3 ]
146+ )
124147 )
125148 server_call_context = ServerCallContext (user = UnauthenticatedUser ())
126149
127150 request_context = await builder .build (
128- params = params , task_id = "t1" , context_id = "c1" , task = None , context = server_call_context
151+ params = params ,
152+ task_id = 't1' ,
153+ context_id = 'c1' ,
154+ task = None ,
155+ context = server_call_context ,
129156 )
130157
131158 self .assertEqual (self .mock_task_store .get .call_count , 3 )
@@ -134,7 +161,9 @@ async def get_side_effect(task_id):
134161 self .mock_task_store .get .assert_any_call (ref_task_id3 )
135162
136163 self .assertIsNotNone (request_context .related_tasks )
137- self .assertEqual (len (request_context .related_tasks ), 2 ) # Only non-None tasks
164+ self .assertEqual (
165+ len (request_context .related_tasks ), 2
166+ ) # Only non-None tasks
138167 self .assertIn (mock_ref_task1 , request_context .related_tasks )
139168 self .assertIn (mock_ref_task3 , request_context .related_tasks )
140169
@@ -144,7 +173,11 @@ async def test_build_populate_true_params_none(self):
144173 )
145174 server_call_context = ServerCallContext (user = UnauthenticatedUser ())
146175 request_context = await builder .build (
147- params = None , task_id = "t1" , context_id = "c1" , task = None , context = server_call_context
176+ params = None ,
177+ task_id = 't1' ,
178+ context_id = 'c1' ,
179+ task = None ,
180+ context = server_call_context ,
148181 )
149182 self .assertEqual (request_context .related_tasks , [])
150183 self .mock_task_store .get .assert_not_called ()
@@ -156,64 +189,88 @@ async def test_build_populate_true_reference_ids_empty_or_none(self):
156189 server_call_context = ServerCallContext (user = UnauthenticatedUser ())
157190
158191 # Test with empty list
159- params_empty_refs = MessageSendParams (message = create_sample_message (reference_task_ids = []))
192+ params_empty_refs = MessageSendParams (
193+ message = create_sample_message (reference_task_ids = [])
194+ )
160195 request_context_empty = await builder .build (
161- params = params_empty_refs , task_id = "t1" , context_id = "c1" , task = None , context = server_call_context
196+ params = params_empty_refs ,
197+ task_id = 't1' ,
198+ context_id = 'c1' ,
199+ task = None ,
200+ context = server_call_context ,
162201 )
163- self .assertEqual (request_context_empty .related_tasks , []) # Should be [] if list is empty
202+ self .assertEqual (
203+ request_context_empty .related_tasks , []
204+ ) # Should be [] if list is empty
164205 self .mock_task_store .get .assert_not_called ()
165206
166- self .mock_task_store .get .reset_mock () # Reset for next call
207+ self .mock_task_store .get .reset_mock () # Reset for next call
167208
168209 # Test with referenceTaskIds=None (Pydantic model might default it to empty list or handle it)
169210 # create_sample_message defaults to [] if None is passed, so this tests the same as above.
170211 # To explicitly test None in Message, we'd have to bypass Pydantic default or modify helper.
171212 # For now, this covers the "no IDs to process" case.
172- msg_with_no_refs = Message (messageId = "m2" , role = Role .user , parts = [], referenceTaskIds = None )
213+ msg_with_no_refs = Message (
214+ messageId = 'm2' , role = Role .user , parts = [], referenceTaskIds = None
215+ )
173216 params_none_refs = MessageSendParams (message = msg_with_no_refs )
174217 request_context_none = await builder .build (
175- params = params_none_refs , task_id = "t2" , context_id = "c2" , task = None , context = server_call_context
218+ params = params_none_refs ,
219+ task_id = 't2' ,
220+ context_id = 'c2' ,
221+ task = None ,
222+ context = server_call_context ,
176223 )
177224 self .assertEqual (request_context_none .related_tasks , [])
178225 self .mock_task_store .get .assert_not_called ()
179226
180-
181227 async def test_build_populate_true_task_store_none (self ):
182228 # This scenario might be prevented by constructor logic if should_populate_referred_tasks is True,
183229 # but testing defensively. The builder might allow task_store=None if it's set post-init,
184230 # or if constructor logic changes. Current SimpleRequestContextBuilder takes it at init.
185231 # If task_store is None, it should not attempt to call get.
186232 builder = SimpleRequestContextBuilder (
187- should_populate_referred_tasks = True , task_store = None # Explicitly None
233+ should_populate_referred_tasks = True ,
234+ task_store = None , # Explicitly None
188235 )
189236 params = MessageSendParams (
190- message = create_sample_message (reference_task_ids = [" ref1" ])
237+ message = create_sample_message (reference_task_ids = [' ref1' ])
191238 )
192239 server_call_context = ServerCallContext (user = UnauthenticatedUser ())
193240
194241 request_context = await builder .build (
195- params = params , task_id = "t1" , context_id = "c1" , task = None , context = server_call_context
242+ params = params ,
243+ task_id = 't1' ,
244+ context_id = 'c1' ,
245+ task = None ,
246+ context = server_call_context ,
196247 )
197248 # Expect related_tasks to be an empty list as task_store is None
198249 self .assertEqual (request_context .related_tasks , [])
199250 # No mock_task_store to check calls on, this test is mostly for graceful handling.
200251
201-
202252 async def test_build_populate_false_with_reference_task_ids (self ):
203253 builder = SimpleRequestContextBuilder (
204- should_populate_referred_tasks = False , task_store = self .mock_task_store
254+ should_populate_referred_tasks = False ,
255+ task_store = self .mock_task_store ,
205256 )
206257 params = MessageSendParams (
207- message = create_sample_message (reference_task_ids = ["ref_task_should_not_be_fetched" ])
258+ message = create_sample_message (
259+ reference_task_ids = ['ref_task_should_not_be_fetched' ]
260+ )
208261 )
209262 server_call_context = ServerCallContext (user = UnauthenticatedUser ())
210263
211264 request_context = await builder .build (
212- params = params , task_id = "t1" , context_id = "c1" , task = None , context = server_call_context
265+ params = params ,
266+ task_id = 't1' ,
267+ context_id = 'c1' ,
268+ task = None ,
269+ context = server_call_context ,
213270 )
214271 self .assertEqual (request_context .related_tasks , [])
215272 self .mock_task_store .get .assert_not_called ()
216273
217274
218- if __name__ == " __main__" :
275+ if __name__ == ' __main__' :
219276 unittest .main ()
0 commit comments