33import json
44from contextlib import asynccontextmanager
55from datetime import datetime , timedelta
6+ from typing import Any , Dict , Iterable , Sequence , cast
67
78import pytest
89from sqlalchemy import select , text , update
1415from agents .extensions .memory .sqlalchemy_session import SQLAlchemySession
1516from tests .fake_model import FakeModel
1617from tests .test_responses import get_text_message
18+ from openai .types .responses .response_output_message_param import ResponseOutputMessageParam
19+ from openai .types .responses .response_output_text_param import ResponseOutputTextParam
20+ from openai .types .responses .response_reasoning_item_param import (
21+ ResponseReasoningItemParam ,
22+ Summary ,
23+ )
1724
1825# Mark all tests in this file as asyncio
1926pytestmark = pytest .mark .asyncio
2229DB_URL = "sqlite+aiosqlite:///:memory:"
2330
2431
32+ def _make_message_item (item_id : str , text_value : str ) -> TResponseInputItem :
33+ content : ResponseOutputTextParam = {
34+ "type" : "output_text" ,
35+ "text" : text_value ,
36+ "annotations" : [],
37+ }
38+ message : ResponseOutputMessageParam = {
39+ "id" : item_id ,
40+ "type" : "message" ,
41+ "role" : "assistant" ,
42+ "status" : "completed" ,
43+ "content" : [content ],
44+ }
45+ return cast (TResponseInputItem , message )
46+
47+
48+ def _make_reasoning_item (item_id : str , summary_text : str ) -> TResponseInputItem :
49+ summary : Summary = {"type" : "summary_text" , "text" : summary_text }
50+ reasoning : ResponseReasoningItemParam = {
51+ "id" : item_id ,
52+ "type" : "reasoning" ,
53+ "summary" : [summary ],
54+ }
55+ return cast (TResponseInputItem , reasoning )
56+
57+
58+ def _item_ids (items : Sequence [TResponseInputItem ]) -> list [str ]:
59+ result : list [str ] = []
60+ for item in items :
61+ item_dict = cast (Dict [str , Any ], item )
62+ result .append (cast (str , item_dict ["id" ]))
63+ return result
64+
65+
2566@pytest .fixture
2667def agent () -> Agent :
2768 """Fixture for a basic agent with a fake model."""
@@ -164,21 +205,9 @@ async def test_get_items_same_timestamp_consistent_order():
164205 session_id = "same_timestamp_test"
165206 session = SQLAlchemySession .from_url (session_id , url = DB_URL , create_tables = True )
166207
167- older_item : TResponseInputItem = {
168- "id" : "older_same_ts" ,
169- "type" : "message" ,
170- "content" : [{"type" : "output_text" , "text" : "old" }],
171- }
172- reasoning_item : TResponseInputItem = {
173- "id" : "rs_same_ts" ,
174- "type" : "reasoning" ,
175- "summary" : "..." ,
176- }
177- message_item : TResponseInputItem = {
178- "id" : "msg_same_ts" ,
179- "type" : "message" ,
180- "content" : [{"type" : "output_text" , "text" : "..." }],
181- }
208+ older_item = _make_message_item ("older_same_ts" , "old" )
209+ reasoning_item = _make_reasoning_item ("rs_same_ts" , "..." )
210+ message_item = _make_message_item ("msg_same_ts" , "..." )
182211 await session .add_items ([older_item ])
183212 await session .add_items ([reasoning_item , message_item ])
184213
@@ -214,13 +243,13 @@ async def test_get_items_same_timestamp_consistent_order():
214243 real_factory = session ._session_factory
215244
216245 class FakeResult :
217- def __init__ (self , rows ):
218- self ._rows = rows
246+ def __init__ (self , rows : Iterable [ Any ] ):
247+ self ._rows = list ( rows )
219248
220- def all (self ):
249+ def all (self ) -> list [ Any ] :
221250 return list (self ._rows )
222251
223- def needs_shuffle (statement : Select ) -> bool :
252+ def needs_shuffle (statement : Any ) -> bool :
224253 if not isinstance (statement , Select ):
225254 return False
226255 orderings = list (statement ._order_by_clause )
@@ -231,22 +260,28 @@ def needs_shuffle(statement: Select) -> bool:
231260
232261 def references_id (clause ) -> bool :
233262 try :
234- return clause .compare (id_asc ) or clause .compare (id_desc )
263+ return bool ( clause .compare (id_asc ) or clause .compare (id_desc ) )
235264 except AttributeError :
236265 return False
237266
238267 if any (references_id (clause ) for clause in orderings ):
239268 return False
240269 # Only shuffle queries that target the messages table.
241- target_tables = {from_clause .name for from_clause in statement .get_final_froms ()}
242- return session ._messages .name in target_tables
270+ target_tables : set [str ] = set ()
271+ for from_clause in statement .get_final_froms ():
272+ name_attr = getattr (from_clause , "name" , None )
273+ if isinstance (name_attr , str ):
274+ target_tables .add (name_attr )
275+ table_name_obj = getattr (session ._messages , "name" , "" )
276+ table_name = table_name_obj if isinstance (table_name_obj , str ) else ""
277+ return bool (table_name in target_tables )
243278
244279 @asynccontextmanager
245280 async def shuffled_session ():
246281 async with real_factory () as inner :
247282 original_execute = inner .execute
248283
249- async def execute_with_shuffle (statement , * args , ** kwargs ) :
284+ async def execute_with_shuffle (statement : Any , * args : Any , ** kwargs : Any ) -> Any :
250285 result = await original_execute (statement , * args , ** kwargs )
251286 if needs_shuffle (statement ):
252287 rows = result .all ()
@@ -255,42 +290,30 @@ async def execute_with_shuffle(statement, *args, **kwargs):
255290 return FakeResult (shuffled )
256291 return result
257292
258- inner .execute = execute_with_shuffle # type: ignore[assignment]
293+ cast ( Any , inner ) .execute = execute_with_shuffle
259294 try :
260295 yield inner
261296 finally :
262- inner .execute = original_execute # type: ignore[assignment]
297+ cast ( Any , inner ) .execute = original_execute
263298
264- session ._session_factory = shuffled_session # type: ignore[assignment]
299+ session ._session_factory = cast ( Any , shuffled_session )
265300 try :
266301 retrieved = await session .get_items ()
267- assert [item ["id" ] for item in retrieved ] == [
268- "older_same_ts" ,
269- "rs_same_ts" ,
270- "msg_same_ts" ,
271- ]
302+ assert _item_ids (retrieved ) == ["older_same_ts" , "rs_same_ts" , "msg_same_ts" ]
272303
273304 latest_two = await session .get_items (limit = 2 )
274- assert [ item [ "id" ] for item in latest_two ] == ["rs_same_ts" , "msg_same_ts" ]
305+ assert _item_ids ( latest_two ) == ["rs_same_ts" , "msg_same_ts" ]
275306 finally :
276- session ._session_factory = real_factory # type: ignore[assignment]
307+ session ._session_factory = real_factory
277308
278309
279310async def test_pop_item_same_timestamp_returns_latest ():
280311 """Test that pop_item returns the newest item when timestamps tie."""
281312 session_id = "same_timestamp_pop_test"
282313 session = SQLAlchemySession .from_url (session_id , url = DB_URL , create_tables = True )
283314
284- reasoning_item : TResponseInputItem = {
285- "id" : "rs_pop_same_ts" ,
286- "type" : "reasoning" ,
287- "summary" : "..." ,
288- }
289- message_item : TResponseInputItem = {
290- "id" : "msg_pop_same_ts" ,
291- "type" : "message" ,
292- "content" : [{"type" : "output_text" , "text" : "..." }],
293- }
315+ reasoning_item = _make_reasoning_item ("rs_pop_same_ts" , "..." )
316+ message_item = _make_message_item ("msg_pop_same_ts" , "..." )
294317 await session .add_items ([reasoning_item , message_item ])
295318
296319 async with session ._session_factory () as sess :
@@ -309,10 +332,10 @@ async def test_pop_item_same_timestamp_returns_latest():
309332
310333 popped = await session .pop_item ()
311334 assert popped is not None
312- assert popped ["id" ] == "msg_pop_same_ts"
335+ assert cast ( Dict [ str , Any ], popped ) ["id" ] == "msg_pop_same_ts"
313336
314337 remaining = await session .get_items ()
315- assert [ item [ "id" ] for item in remaining ] == ["rs_pop_same_ts" ]
338+ assert _item_ids ( remaining ) == ["rs_pop_same_ts" ]
316339
317340
318341async def test_get_items_orders_by_id_for_ties ():
@@ -322,35 +345,35 @@ async def test_get_items_orders_by_id_for_ties():
322345
323346 await session .add_items (
324347 [
325- { "id" : " rs_first" , "type" : "reasoning" } ,
326- { "id" : " msg_second" , "type" : "message" } ,
348+ _make_reasoning_item ( " rs_first" , "..." ) ,
349+ _make_message_item ( " msg_second" , "..." ) ,
327350 ]
328351 )
329352
330353 real_factory = session ._session_factory
331- recorded : list = []
354+ recorded : list [ Any ] = []
332355
333356 @asynccontextmanager
334357 async def wrapped_session ():
335358 async with real_factory () as inner :
336359 original_execute = inner .execute
337360
338- async def recording_execute (statement , * args , ** kwargs ) :
361+ async def recording_execute (statement : Any , * args : Any , ** kwargs : Any ) -> Any :
339362 recorded .append (statement )
340363 return await original_execute (statement , * args , ** kwargs )
341364
342- inner .execute = recording_execute # type: ignore[assignment]
365+ cast ( Any , inner ) .execute = recording_execute
343366 try :
344367 yield inner
345368 finally :
346- inner .execute = original_execute # type: ignore[assignment]
369+ cast ( Any , inner ) .execute = original_execute
347370
348- session ._session_factory = wrapped_session # type: ignore[assignment]
371+ session ._session_factory = cast ( Any , wrapped_session )
349372 try :
350373 retrieved_full = await session .get_items ()
351374 retrieved_limited = await session .get_items (limit = 2 )
352375 finally :
353- session ._session_factory = real_factory # type: ignore[assignment]
376+ session ._session_factory = real_factory
354377
355378 assert len (recorded ) >= 2
356379 orderings_full = [str (clause ) for clause in recorded [0 ]._order_by_clause ]
@@ -365,5 +388,5 @@ async def recording_execute(statement, *args, **kwargs):
365388 "agent_messages.id DESC" ,
366389 ]
367390
368- assert [ item [ "id" ] for item in retrieved_full ] == ["rs_first" , "msg_second" ]
369- assert [ item [ "id" ] for item in retrieved_limited ] == ["rs_first" , "msg_second" ]
391+ assert _item_ids ( retrieved_full ) == ["rs_first" , "msg_second" ]
392+ assert _item_ids ( retrieved_limited ) == ["rs_first" , "msg_second" ]
0 commit comments