From e361ab90435d3384bde1cea11cd92b34a0ca053b Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 17 Oct 2025 15:05:41 +0900 Subject: [PATCH 1/4] fix: #1900 fix a bug where SQLAlchemySession could return items in an invalid order --- .../extensions/memory/sqlalchemy_session.py | 15 +- .../memory/test_sqlalchemy_session.py | 216 ++++++++++++++++++ 2 files changed, 228 insertions(+), 3 deletions(-) diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index e1d7f248d..e1fc885bb 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -195,7 +195,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: stmt = ( select(self._messages.c.message_data) .where(self._messages.c.session_id == self.session_id) - .order_by(self._messages.c.created_at.asc()) + .order_by( + self._messages.c.created_at.asc(), + self._messages.c.id.asc(), + ) ) else: stmt = ( @@ -203,7 +206,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: .where(self._messages.c.session_id == self.session_id) # Use DESC + LIMIT to get the latest N # then reverse later for chronological order. - .order_by(self._messages.c.created_at.desc()) + .order_by( + self._messages.c.created_at.desc(), + self._messages.c.id.desc(), + ) .limit(limit) ) @@ -278,7 +284,10 @@ async def pop_item(self) -> TResponseInputItem | None: subq = ( select(self._messages.c.id) .where(self._messages.c.session_id == self.session_id) - .order_by(self._messages.c.created_at.desc()) + .order_by( + self._messages.c.created_at.desc(), + self._messages.c.id.desc(), + ) .limit(1) ) res = await sess.execute(subq) diff --git a/tests/extensions/memory/test_sqlalchemy_session.py b/tests/extensions/memory/test_sqlalchemy_session.py index e1ce3e10b..4265144f4 100644 --- a/tests/extensions/memory/test_sqlalchemy_session.py +++ b/tests/extensions/memory/test_sqlalchemy_session.py @@ -1,6 +1,12 @@ from __future__ import annotations +from contextlib import asynccontextmanager +from datetime import datetime, timedelta +import json + import pytest +from sqlalchemy import select, text, update +from sqlalchemy.sql import Select pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed @@ -151,3 +157,213 @@ async def test_add_empty_items_list(): items_after_add = await session.get_items() assert len(items_after_add) == 0 + + +async def test_get_items_same_timestamp_consistent_order(): + """Test that items with identical timestamps keep insertion order.""" + session_id = "same_timestamp_test" + session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True) + + older_item: TResponseInputItem = { + "id": "older_same_ts", + "type": "message", + "content": [{"type": "output_text", "text": "old"}], + } + reasoning_item: TResponseInputItem = { + "id": "rs_same_ts", + "type": "reasoning", + "summary": "...", + } + message_item: TResponseInputItem = { + "id": "msg_same_ts", + "type": "message", + "content": [{"type": "output_text", "text": "..."}], + } + await session.add_items([older_item]) + await session.add_items([reasoning_item, message_item]) + + async with session._session_factory() as sess: + rows = await sess.execute( + select(session._messages.c.id, session._messages.c.message_data).where( + session._messages.c.session_id == session.session_id + ) + ) + id_map = { + json.loads(message_json)["id"]: row_id + for row_id, message_json in rows.fetchall() + } + shared = datetime(2025, 10, 15, 17, 26, 39, 132483) + older = shared - timedelta(milliseconds=1) + await sess.execute( + update(session._messages) + .where(session._messages.c.id.in_( + [ + id_map["rs_same_ts"], + id_map["msg_same_ts"], + ] + )) + .values(created_at=shared) + ) + await sess.execute( + update(session._messages) + .where(session._messages.c.id == id_map["older_same_ts"]) + .values(created_at=older) + ) + await sess.commit() + + real_factory = session._session_factory + + class FakeResult: + def __init__(self, rows): + self._rows = rows + + def all(self): + return list(self._rows) + + def needs_shuffle(statement: Select) -> bool: + if not isinstance(statement, Select): + return False + orderings = list(statement._order_by_clause) + if not orderings: + return False + id_asc = session._messages.c.id.asc() + id_desc = session._messages.c.id.desc() + + def references_id(clause) -> bool: + try: + return clause.compare(id_asc) or clause.compare(id_desc) + except AttributeError: + return False + + if any(references_id(clause) for clause in orderings): + return False + # Only shuffle queries that target the messages table. + target_tables = {from_clause.name for from_clause in statement.get_final_froms()} + return session._messages.name in target_tables + + @asynccontextmanager + async def shuffled_session(): + async with real_factory() as inner: + original_execute = inner.execute + + async def execute_with_shuffle(statement, *args, **kwargs): + result = await original_execute(statement, *args, **kwargs) + if needs_shuffle(statement): + rows = result.all() + shuffled = list(rows) + shuffled.reverse() + return FakeResult(shuffled) + return result + + inner.execute = execute_with_shuffle # type: ignore[assignment] + try: + yield inner + finally: + inner.execute = original_execute # type: ignore[assignment] + + session._session_factory = shuffled_session # type: ignore[assignment] + try: + retrieved = await session.get_items() + assert [item["id"] for item in retrieved] == [ + "older_same_ts", + "rs_same_ts", + "msg_same_ts", + ] + + latest_two = await session.get_items(limit=2) + assert [item["id"] for item in latest_two] == ["rs_same_ts", "msg_same_ts"] + finally: + session._session_factory = real_factory # type: ignore[assignment] + + +async def test_pop_item_same_timestamp_returns_latest(): + """Test that pop_item returns the newest item when timestamps tie.""" + session_id = "same_timestamp_pop_test" + session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True) + + reasoning_item: TResponseInputItem = { + "id": "rs_pop_same_ts", + "type": "reasoning", + "summary": "...", + } + message_item: TResponseInputItem = { + "id": "msg_pop_same_ts", + "type": "message", + "content": [{"type": "output_text", "text": "..."}], + } + await session.add_items([reasoning_item, message_item]) + + async with session._session_factory() as sess: + await sess.execute( + text( + "UPDATE agent_messages " + "SET created_at = :created_at " + "WHERE session_id = :session_id" + ), + { + "created_at": "2025-10-15 17:26:39.132483", + "session_id": session.session_id, + }, + ) + await sess.commit() + + popped = await session.pop_item() + assert popped is not None + assert popped["id"] == "msg_pop_same_ts" + + remaining = await session.get_items() + assert [item["id"] for item in remaining] == ["rs_pop_same_ts"] + + +async def test_get_items_orders_by_id_for_ties(): + """Test that get_items adds id ordering to break timestamp ties.""" + session_id = "order_by_id_test" + session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True) + + await session.add_items( + [ + {"id": "rs_first", "type": "reasoning"}, + {"id": "msg_second", "type": "message"}, + ] + ) + + real_factory = session._session_factory + recorded: list = [] + + @asynccontextmanager + async def wrapped_session(): + async with real_factory() as inner: + original_execute = inner.execute + + async def recording_execute(statement, *args, **kwargs): + recorded.append(statement) + return await original_execute(statement, *args, **kwargs) + + inner.execute = recording_execute # type: ignore[assignment] + try: + yield inner + finally: + inner.execute = original_execute # type: ignore[assignment] + + session._session_factory = wrapped_session # type: ignore[assignment] + try: + retrieved_full = await session.get_items() + retrieved_limited = await session.get_items(limit=2) + finally: + session._session_factory = real_factory # type: ignore[assignment] + + assert len(recorded) >= 2 + orderings_full = [str(clause) for clause in recorded[0]._order_by_clause] + assert orderings_full == [ + "agent_messages.created_at ASC", + "agent_messages.id ASC", + ] + + orderings_limited = [str(clause) for clause in recorded[1]._order_by_clause] + assert orderings_limited == [ + "agent_messages.created_at DESC", + "agent_messages.id DESC", + ] + + assert [item["id"] for item in retrieved_full] == ["rs_first", "msg_second"] + assert [item["id"] for item in retrieved_limited] == ["rs_first", "msg_second"] From 421108d364462ccff37dbddd3ae06696167933d5 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 17 Oct 2025 15:17:16 +0900 Subject: [PATCH 2/4] fix make lint --- tests/extensions/memory/test_sqlalchemy_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/extensions/memory/test_sqlalchemy_session.py b/tests/extensions/memory/test_sqlalchemy_session.py index 4265144f4..0188e4000 100644 --- a/tests/extensions/memory/test_sqlalchemy_session.py +++ b/tests/extensions/memory/test_sqlalchemy_session.py @@ -1,8 +1,8 @@ from __future__ import annotations +import json from contextlib import asynccontextmanager from datetime import datetime, timedelta -import json import pytest from sqlalchemy import select, text, update From 412ebb123d68eaf3b76b2b206ff631e188b7231a Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 17 Oct 2025 16:06:43 +0900 Subject: [PATCH 3/4] fix mypy issues --- .../memory/test_sqlalchemy_session.py | 133 ++++++++++-------- 1 file changed, 78 insertions(+), 55 deletions(-) diff --git a/tests/extensions/memory/test_sqlalchemy_session.py b/tests/extensions/memory/test_sqlalchemy_session.py index 0188e4000..68435bba9 100644 --- a/tests/extensions/memory/test_sqlalchemy_session.py +++ b/tests/extensions/memory/test_sqlalchemy_session.py @@ -3,6 +3,7 @@ import json from contextlib import asynccontextmanager from datetime import datetime, timedelta +from typing import Any, Dict, Iterable, Sequence, cast import pytest from sqlalchemy import select, text, update @@ -14,6 +15,12 @@ from agents.extensions.memory.sqlalchemy_session import SQLAlchemySession from tests.fake_model import FakeModel from tests.test_responses import get_text_message +from openai.types.responses.response_output_message_param import ResponseOutputMessageParam +from openai.types.responses.response_output_text_param import ResponseOutputTextParam +from openai.types.responses.response_reasoning_item_param import ( + ResponseReasoningItemParam, + Summary, +) # Mark all tests in this file as asyncio pytestmark = pytest.mark.asyncio @@ -22,6 +29,40 @@ DB_URL = "sqlite+aiosqlite:///:memory:" +def _make_message_item(item_id: str, text_value: str) -> TResponseInputItem: + content: ResponseOutputTextParam = { + "type": "output_text", + "text": text_value, + "annotations": [], + } + message: ResponseOutputMessageParam = { + "id": item_id, + "type": "message", + "role": "assistant", + "status": "completed", + "content": [content], + } + return cast(TResponseInputItem, message) + + +def _make_reasoning_item(item_id: str, summary_text: str) -> TResponseInputItem: + summary: Summary = {"type": "summary_text", "text": summary_text} + reasoning: ResponseReasoningItemParam = { + "id": item_id, + "type": "reasoning", + "summary": [summary], + } + return cast(TResponseInputItem, reasoning) + + +def _item_ids(items: Sequence[TResponseInputItem]) -> list[str]: + result: list[str] = [] + for item in items: + item_dict = cast(Dict[str, Any], item) + result.append(cast(str, item_dict["id"])) + return result + + @pytest.fixture def agent() -> Agent: """Fixture for a basic agent with a fake model.""" @@ -164,21 +205,9 @@ async def test_get_items_same_timestamp_consistent_order(): session_id = "same_timestamp_test" session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True) - older_item: TResponseInputItem = { - "id": "older_same_ts", - "type": "message", - "content": [{"type": "output_text", "text": "old"}], - } - reasoning_item: TResponseInputItem = { - "id": "rs_same_ts", - "type": "reasoning", - "summary": "...", - } - message_item: TResponseInputItem = { - "id": "msg_same_ts", - "type": "message", - "content": [{"type": "output_text", "text": "..."}], - } + older_item = _make_message_item("older_same_ts", "old") + reasoning_item = _make_reasoning_item("rs_same_ts", "...") + message_item = _make_message_item("msg_same_ts", "...") await session.add_items([older_item]) await session.add_items([reasoning_item, message_item]) @@ -214,13 +243,13 @@ async def test_get_items_same_timestamp_consistent_order(): real_factory = session._session_factory class FakeResult: - def __init__(self, rows): - self._rows = rows + def __init__(self, rows: Iterable[Any]): + self._rows = list(rows) - def all(self): + def all(self) -> list[Any]: return list(self._rows) - def needs_shuffle(statement: Select) -> bool: + def needs_shuffle(statement: Any) -> bool: if not isinstance(statement, Select): return False orderings = list(statement._order_by_clause) @@ -231,22 +260,28 @@ def needs_shuffle(statement: Select) -> bool: def references_id(clause) -> bool: try: - return clause.compare(id_asc) or clause.compare(id_desc) + return bool(clause.compare(id_asc) or clause.compare(id_desc)) except AttributeError: return False if any(references_id(clause) for clause in orderings): return False # Only shuffle queries that target the messages table. - target_tables = {from_clause.name for from_clause in statement.get_final_froms()} - return session._messages.name in target_tables + target_tables: set[str] = set() + for from_clause in statement.get_final_froms(): + name_attr = getattr(from_clause, "name", None) + if isinstance(name_attr, str): + target_tables.add(name_attr) + table_name_obj = getattr(session._messages, "name", "") + table_name = table_name_obj if isinstance(table_name_obj, str) else "" + return bool(table_name in target_tables) @asynccontextmanager async def shuffled_session(): async with real_factory() as inner: original_execute = inner.execute - async def execute_with_shuffle(statement, *args, **kwargs): + async def execute_with_shuffle(statement: Any, *args: Any, **kwargs: Any) -> Any: result = await original_execute(statement, *args, **kwargs) if needs_shuffle(statement): rows = result.all() @@ -255,25 +290,21 @@ async def execute_with_shuffle(statement, *args, **kwargs): return FakeResult(shuffled) return result - inner.execute = execute_with_shuffle # type: ignore[assignment] + cast(Any, inner).execute = execute_with_shuffle try: yield inner finally: - inner.execute = original_execute # type: ignore[assignment] + cast(Any, inner).execute = original_execute - session._session_factory = shuffled_session # type: ignore[assignment] + session._session_factory = cast(Any, shuffled_session) try: retrieved = await session.get_items() - assert [item["id"] for item in retrieved] == [ - "older_same_ts", - "rs_same_ts", - "msg_same_ts", - ] + assert _item_ids(retrieved) == ["older_same_ts", "rs_same_ts", "msg_same_ts"] latest_two = await session.get_items(limit=2) - assert [item["id"] for item in latest_two] == ["rs_same_ts", "msg_same_ts"] + assert _item_ids(latest_two) == ["rs_same_ts", "msg_same_ts"] finally: - session._session_factory = real_factory # type: ignore[assignment] + session._session_factory = real_factory async def test_pop_item_same_timestamp_returns_latest(): @@ -281,16 +312,8 @@ async def test_pop_item_same_timestamp_returns_latest(): session_id = "same_timestamp_pop_test" session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True) - reasoning_item: TResponseInputItem = { - "id": "rs_pop_same_ts", - "type": "reasoning", - "summary": "...", - } - message_item: TResponseInputItem = { - "id": "msg_pop_same_ts", - "type": "message", - "content": [{"type": "output_text", "text": "..."}], - } + reasoning_item = _make_reasoning_item("rs_pop_same_ts", "...") + message_item = _make_message_item("msg_pop_same_ts", "...") await session.add_items([reasoning_item, message_item]) async with session._session_factory() as sess: @@ -309,10 +332,10 @@ async def test_pop_item_same_timestamp_returns_latest(): popped = await session.pop_item() assert popped is not None - assert popped["id"] == "msg_pop_same_ts" + assert cast(Dict[str, Any], popped)["id"] == "msg_pop_same_ts" remaining = await session.get_items() - assert [item["id"] for item in remaining] == ["rs_pop_same_ts"] + assert _item_ids(remaining) == ["rs_pop_same_ts"] async def test_get_items_orders_by_id_for_ties(): @@ -322,35 +345,35 @@ async def test_get_items_orders_by_id_for_ties(): await session.add_items( [ - {"id": "rs_first", "type": "reasoning"}, - {"id": "msg_second", "type": "message"}, + _make_reasoning_item("rs_first", "..."), + _make_message_item("msg_second", "..."), ] ) real_factory = session._session_factory - recorded: list = [] + recorded: list[Any] = [] @asynccontextmanager async def wrapped_session(): async with real_factory() as inner: original_execute = inner.execute - async def recording_execute(statement, *args, **kwargs): + async def recording_execute(statement: Any, *args: Any, **kwargs: Any) -> Any: recorded.append(statement) return await original_execute(statement, *args, **kwargs) - inner.execute = recording_execute # type: ignore[assignment] + cast(Any, inner).execute = recording_execute try: yield inner finally: - inner.execute = original_execute # type: ignore[assignment] + cast(Any, inner).execute = original_execute - session._session_factory = wrapped_session # type: ignore[assignment] + session._session_factory = cast(Any, wrapped_session) try: retrieved_full = await session.get_items() retrieved_limited = await session.get_items(limit=2) finally: - session._session_factory = real_factory # type: ignore[assignment] + session._session_factory = real_factory assert len(recorded) >= 2 orderings_full = [str(clause) for clause in recorded[0]._order_by_clause] @@ -365,5 +388,5 @@ async def recording_execute(statement, *args, **kwargs): "agent_messages.id DESC", ] - assert [item["id"] for item in retrieved_full] == ["rs_first", "msg_second"] - assert [item["id"] for item in retrieved_limited] == ["rs_first", "msg_second"] + assert _item_ids(retrieved_full) == ["rs_first", "msg_second"] + assert _item_ids(retrieved_limited) == ["rs_first", "msg_second"] From a9d4ff87f246bf4010b93b4bd3d5edaf0fa03880 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Fri, 17 Oct 2025 16:22:20 +0900 Subject: [PATCH 4/4] fix lint issues --- .../memory/test_sqlalchemy_session.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/extensions/memory/test_sqlalchemy_session.py b/tests/extensions/memory/test_sqlalchemy_session.py index 68435bba9..496d0b027 100644 --- a/tests/extensions/memory/test_sqlalchemy_session.py +++ b/tests/extensions/memory/test_sqlalchemy_session.py @@ -1,11 +1,18 @@ from __future__ import annotations import json +from collections.abc import Iterable, Sequence from contextlib import asynccontextmanager from datetime import datetime, timedelta -from typing import Any, Dict, Iterable, Sequence, cast +from typing import Any, cast import pytest +from openai.types.responses.response_output_message_param import ResponseOutputMessageParam +from openai.types.responses.response_output_text_param import ResponseOutputTextParam +from openai.types.responses.response_reasoning_item_param import ( + ResponseReasoningItemParam, + Summary, +) from sqlalchemy import select, text, update from sqlalchemy.sql import Select @@ -15,12 +22,6 @@ from agents.extensions.memory.sqlalchemy_session import SQLAlchemySession from tests.fake_model import FakeModel from tests.test_responses import get_text_message -from openai.types.responses.response_output_message_param import ResponseOutputMessageParam -from openai.types.responses.response_output_text_param import ResponseOutputTextParam -from openai.types.responses.response_reasoning_item_param import ( - ResponseReasoningItemParam, - Summary, -) # Mark all tests in this file as asyncio pytestmark = pytest.mark.asyncio @@ -58,7 +59,7 @@ def _make_reasoning_item(item_id: str, summary_text: str) -> TResponseInputItem: def _item_ids(items: Sequence[TResponseInputItem]) -> list[str]: result: list[str] = [] for item in items: - item_dict = cast(Dict[str, Any], item) + item_dict = cast(dict[str, Any], item) result.append(cast(str, item_dict["id"])) return result @@ -332,7 +333,7 @@ async def test_pop_item_same_timestamp_returns_latest(): popped = await session.pop_item() assert popped is not None - assert cast(Dict[str, Any], popped)["id"] == "msg_pop_same_ts" + assert cast(dict[str, Any], popped)["id"] == "msg_pop_same_ts" remaining = await session.get_items() assert _item_ids(remaining) == ["rs_pop_same_ts"]