Skip to content

Commit 412ebb1

Browse files
committed
fix mypy issues
1 parent 421108d commit 412ebb1

File tree

1 file changed

+78
-55
lines changed

1 file changed

+78
-55
lines changed

tests/extensions/memory/test_sqlalchemy_session.py

Lines changed: 78 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
from contextlib import asynccontextmanager
55
from datetime import datetime, timedelta
6+
from typing import Any, Dict, Iterable, Sequence, cast
67

78
import pytest
89
from sqlalchemy import select, text, update
@@ -14,6 +15,12 @@
1415
from agents.extensions.memory.sqlalchemy_session import SQLAlchemySession
1516
from tests.fake_model import FakeModel
1617
from tests.test_responses import get_text_message
18+
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
19+
from openai.types.responses.response_output_text_param import ResponseOutputTextParam
20+
from openai.types.responses.response_reasoning_item_param import (
21+
ResponseReasoningItemParam,
22+
Summary,
23+
)
1724

1825
# Mark all tests in this file as asyncio
1926
pytestmark = pytest.mark.asyncio
@@ -22,6 +29,40 @@
2229
DB_URL = "sqlite+aiosqlite:///:memory:"
2330

2431

32+
def _make_message_item(item_id: str, text_value: str) -> TResponseInputItem:
33+
content: ResponseOutputTextParam = {
34+
"type": "output_text",
35+
"text": text_value,
36+
"annotations": [],
37+
}
38+
message: ResponseOutputMessageParam = {
39+
"id": item_id,
40+
"type": "message",
41+
"role": "assistant",
42+
"status": "completed",
43+
"content": [content],
44+
}
45+
return cast(TResponseInputItem, message)
46+
47+
48+
def _make_reasoning_item(item_id: str, summary_text: str) -> TResponseInputItem:
49+
summary: Summary = {"type": "summary_text", "text": summary_text}
50+
reasoning: ResponseReasoningItemParam = {
51+
"id": item_id,
52+
"type": "reasoning",
53+
"summary": [summary],
54+
}
55+
return cast(TResponseInputItem, reasoning)
56+
57+
58+
def _item_ids(items: Sequence[TResponseInputItem]) -> list[str]:
59+
result: list[str] = []
60+
for item in items:
61+
item_dict = cast(Dict[str, Any], item)
62+
result.append(cast(str, item_dict["id"]))
63+
return result
64+
65+
2566
@pytest.fixture
2667
def agent() -> Agent:
2768
"""Fixture for a basic agent with a fake model."""
@@ -164,21 +205,9 @@ async def test_get_items_same_timestamp_consistent_order():
164205
session_id = "same_timestamp_test"
165206
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
166207

167-
older_item: TResponseInputItem = {
168-
"id": "older_same_ts",
169-
"type": "message",
170-
"content": [{"type": "output_text", "text": "old"}],
171-
}
172-
reasoning_item: TResponseInputItem = {
173-
"id": "rs_same_ts",
174-
"type": "reasoning",
175-
"summary": "...",
176-
}
177-
message_item: TResponseInputItem = {
178-
"id": "msg_same_ts",
179-
"type": "message",
180-
"content": [{"type": "output_text", "text": "..."}],
181-
}
208+
older_item = _make_message_item("older_same_ts", "old")
209+
reasoning_item = _make_reasoning_item("rs_same_ts", "...")
210+
message_item = _make_message_item("msg_same_ts", "...")
182211
await session.add_items([older_item])
183212
await session.add_items([reasoning_item, message_item])
184213

@@ -214,13 +243,13 @@ async def test_get_items_same_timestamp_consistent_order():
214243
real_factory = session._session_factory
215244

216245
class FakeResult:
217-
def __init__(self, rows):
218-
self._rows = rows
246+
def __init__(self, rows: Iterable[Any]):
247+
self._rows = list(rows)
219248

220-
def all(self):
249+
def all(self) -> list[Any]:
221250
return list(self._rows)
222251

223-
def needs_shuffle(statement: Select) -> bool:
252+
def needs_shuffle(statement: Any) -> bool:
224253
if not isinstance(statement, Select):
225254
return False
226255
orderings = list(statement._order_by_clause)
@@ -231,22 +260,28 @@ def needs_shuffle(statement: Select) -> bool:
231260

232261
def references_id(clause) -> bool:
233262
try:
234-
return clause.compare(id_asc) or clause.compare(id_desc)
263+
return bool(clause.compare(id_asc) or clause.compare(id_desc))
235264
except AttributeError:
236265
return False
237266

238267
if any(references_id(clause) for clause in orderings):
239268
return False
240269
# Only shuffle queries that target the messages table.
241-
target_tables = {from_clause.name for from_clause in statement.get_final_froms()}
242-
return session._messages.name in target_tables
270+
target_tables: set[str] = set()
271+
for from_clause in statement.get_final_froms():
272+
name_attr = getattr(from_clause, "name", None)
273+
if isinstance(name_attr, str):
274+
target_tables.add(name_attr)
275+
table_name_obj = getattr(session._messages, "name", "")
276+
table_name = table_name_obj if isinstance(table_name_obj, str) else ""
277+
return bool(table_name in target_tables)
243278

244279
@asynccontextmanager
245280
async def shuffled_session():
246281
async with real_factory() as inner:
247282
original_execute = inner.execute
248283

249-
async def execute_with_shuffle(statement, *args, **kwargs):
284+
async def execute_with_shuffle(statement: Any, *args: Any, **kwargs: Any) -> Any:
250285
result = await original_execute(statement, *args, **kwargs)
251286
if needs_shuffle(statement):
252287
rows = result.all()
@@ -255,42 +290,30 @@ async def execute_with_shuffle(statement, *args, **kwargs):
255290
return FakeResult(shuffled)
256291
return result
257292

258-
inner.execute = execute_with_shuffle # type: ignore[assignment]
293+
cast(Any, inner).execute = execute_with_shuffle
259294
try:
260295
yield inner
261296
finally:
262-
inner.execute = original_execute # type: ignore[assignment]
297+
cast(Any, inner).execute = original_execute
263298

264-
session._session_factory = shuffled_session # type: ignore[assignment]
299+
session._session_factory = cast(Any, shuffled_session)
265300
try:
266301
retrieved = await session.get_items()
267-
assert [item["id"] for item in retrieved] == [
268-
"older_same_ts",
269-
"rs_same_ts",
270-
"msg_same_ts",
271-
]
302+
assert _item_ids(retrieved) == ["older_same_ts", "rs_same_ts", "msg_same_ts"]
272303

273304
latest_two = await session.get_items(limit=2)
274-
assert [item["id"] for item in latest_two] == ["rs_same_ts", "msg_same_ts"]
305+
assert _item_ids(latest_two) == ["rs_same_ts", "msg_same_ts"]
275306
finally:
276-
session._session_factory = real_factory # type: ignore[assignment]
307+
session._session_factory = real_factory
277308

278309

279310
async def test_pop_item_same_timestamp_returns_latest():
280311
"""Test that pop_item returns the newest item when timestamps tie."""
281312
session_id = "same_timestamp_pop_test"
282313
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
283314

284-
reasoning_item: TResponseInputItem = {
285-
"id": "rs_pop_same_ts",
286-
"type": "reasoning",
287-
"summary": "...",
288-
}
289-
message_item: TResponseInputItem = {
290-
"id": "msg_pop_same_ts",
291-
"type": "message",
292-
"content": [{"type": "output_text", "text": "..."}],
293-
}
315+
reasoning_item = _make_reasoning_item("rs_pop_same_ts", "...")
316+
message_item = _make_message_item("msg_pop_same_ts", "...")
294317
await session.add_items([reasoning_item, message_item])
295318

296319
async with session._session_factory() as sess:
@@ -309,10 +332,10 @@ async def test_pop_item_same_timestamp_returns_latest():
309332

310333
popped = await session.pop_item()
311334
assert popped is not None
312-
assert popped["id"] == "msg_pop_same_ts"
335+
assert cast(Dict[str, Any], popped)["id"] == "msg_pop_same_ts"
313336

314337
remaining = await session.get_items()
315-
assert [item["id"] for item in remaining] == ["rs_pop_same_ts"]
338+
assert _item_ids(remaining) == ["rs_pop_same_ts"]
316339

317340

318341
async def test_get_items_orders_by_id_for_ties():
@@ -322,35 +345,35 @@ async def test_get_items_orders_by_id_for_ties():
322345

323346
await session.add_items(
324347
[
325-
{"id": "rs_first", "type": "reasoning"},
326-
{"id": "msg_second", "type": "message"},
348+
_make_reasoning_item("rs_first", "..."),
349+
_make_message_item("msg_second", "..."),
327350
]
328351
)
329352

330353
real_factory = session._session_factory
331-
recorded: list = []
354+
recorded: list[Any] = []
332355

333356
@asynccontextmanager
334357
async def wrapped_session():
335358
async with real_factory() as inner:
336359
original_execute = inner.execute
337360

338-
async def recording_execute(statement, *args, **kwargs):
361+
async def recording_execute(statement: Any, *args: Any, **kwargs: Any) -> Any:
339362
recorded.append(statement)
340363
return await original_execute(statement, *args, **kwargs)
341364

342-
inner.execute = recording_execute # type: ignore[assignment]
365+
cast(Any, inner).execute = recording_execute
343366
try:
344367
yield inner
345368
finally:
346-
inner.execute = original_execute # type: ignore[assignment]
369+
cast(Any, inner).execute = original_execute
347370

348-
session._session_factory = wrapped_session # type: ignore[assignment]
371+
session._session_factory = cast(Any, wrapped_session)
349372
try:
350373
retrieved_full = await session.get_items()
351374
retrieved_limited = await session.get_items(limit=2)
352375
finally:
353-
session._session_factory = real_factory # type: ignore[assignment]
376+
session._session_factory = real_factory
354377

355378
assert len(recorded) >= 2
356379
orderings_full = [str(clause) for clause in recorded[0]._order_by_clause]
@@ -365,5 +388,5 @@ async def recording_execute(statement, *args, **kwargs):
365388
"agent_messages.id DESC",
366389
]
367390

368-
assert [item["id"] for item in retrieved_full] == ["rs_first", "msg_second"]
369-
assert [item["id"] for item in retrieved_limited] == ["rs_first", "msg_second"]
391+
assert _item_ids(retrieved_full) == ["rs_first", "msg_second"]
392+
assert _item_ids(retrieved_limited) == ["rs_first", "msg_second"]

0 commit comments

Comments
 (0)