Skip to content

Commit e361ab9

Browse files
committed
fix: #1900 fix a bug where SQLAlchemySession could return items in an invalid order
1 parent d9f1d5f commit e361ab9

File tree

2 files changed

+228
-3
lines changed

2 files changed

+228
-3
lines changed

src/agents/extensions/memory/sqlalchemy_session.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,21 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
195195
stmt = (
196196
select(self._messages.c.message_data)
197197
.where(self._messages.c.session_id == self.session_id)
198-
.order_by(self._messages.c.created_at.asc())
198+
.order_by(
199+
self._messages.c.created_at.asc(),
200+
self._messages.c.id.asc(),
201+
)
199202
)
200203
else:
201204
stmt = (
202205
select(self._messages.c.message_data)
203206
.where(self._messages.c.session_id == self.session_id)
204207
# Use DESC + LIMIT to get the latest N
205208
# then reverse later for chronological order.
206-
.order_by(self._messages.c.created_at.desc())
209+
.order_by(
210+
self._messages.c.created_at.desc(),
211+
self._messages.c.id.desc(),
212+
)
207213
.limit(limit)
208214
)
209215

@@ -278,7 +284,10 @@ async def pop_item(self) -> TResponseInputItem | None:
278284
subq = (
279285
select(self._messages.c.id)
280286
.where(self._messages.c.session_id == self.session_id)
281-
.order_by(self._messages.c.created_at.desc())
287+
.order_by(
288+
self._messages.c.created_at.desc(),
289+
self._messages.c.id.desc(),
290+
)
282291
.limit(1)
283292
)
284293
res = await sess.execute(subq)

tests/extensions/memory/test_sqlalchemy_session.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
from __future__ import annotations
22

3+
from contextlib import asynccontextmanager
4+
from datetime import datetime, timedelta
5+
import json
6+
37
import pytest
8+
from sqlalchemy import select, text, update
9+
from sqlalchemy.sql import Select
410

511
pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed
612

@@ -151,3 +157,213 @@ async def test_add_empty_items_list():
151157

152158
items_after_add = await session.get_items()
153159
assert len(items_after_add) == 0
160+
161+
162+
async def test_get_items_same_timestamp_consistent_order():
163+
"""Test that items with identical timestamps keep insertion order."""
164+
session_id = "same_timestamp_test"
165+
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
166+
167+
older_item: TResponseInputItem = {
168+
"id": "older_same_ts",
169+
"type": "message",
170+
"content": [{"type": "output_text", "text": "old"}],
171+
}
172+
reasoning_item: TResponseInputItem = {
173+
"id": "rs_same_ts",
174+
"type": "reasoning",
175+
"summary": "...",
176+
}
177+
message_item: TResponseInputItem = {
178+
"id": "msg_same_ts",
179+
"type": "message",
180+
"content": [{"type": "output_text", "text": "..."}],
181+
}
182+
await session.add_items([older_item])
183+
await session.add_items([reasoning_item, message_item])
184+
185+
async with session._session_factory() as sess:
186+
rows = await sess.execute(
187+
select(session._messages.c.id, session._messages.c.message_data).where(
188+
session._messages.c.session_id == session.session_id
189+
)
190+
)
191+
id_map = {
192+
json.loads(message_json)["id"]: row_id
193+
for row_id, message_json in rows.fetchall()
194+
}
195+
shared = datetime(2025, 10, 15, 17, 26, 39, 132483)
196+
older = shared - timedelta(milliseconds=1)
197+
await sess.execute(
198+
update(session._messages)
199+
.where(session._messages.c.id.in_(
200+
[
201+
id_map["rs_same_ts"],
202+
id_map["msg_same_ts"],
203+
]
204+
))
205+
.values(created_at=shared)
206+
)
207+
await sess.execute(
208+
update(session._messages)
209+
.where(session._messages.c.id == id_map["older_same_ts"])
210+
.values(created_at=older)
211+
)
212+
await sess.commit()
213+
214+
real_factory = session._session_factory
215+
216+
class FakeResult:
217+
def __init__(self, rows):
218+
self._rows = rows
219+
220+
def all(self):
221+
return list(self._rows)
222+
223+
def needs_shuffle(statement: Select) -> bool:
224+
if not isinstance(statement, Select):
225+
return False
226+
orderings = list(statement._order_by_clause)
227+
if not orderings:
228+
return False
229+
id_asc = session._messages.c.id.asc()
230+
id_desc = session._messages.c.id.desc()
231+
232+
def references_id(clause) -> bool:
233+
try:
234+
return clause.compare(id_asc) or clause.compare(id_desc)
235+
except AttributeError:
236+
return False
237+
238+
if any(references_id(clause) for clause in orderings):
239+
return False
240+
# Only shuffle queries that target the messages table.
241+
target_tables = {from_clause.name for from_clause in statement.get_final_froms()}
242+
return session._messages.name in target_tables
243+
244+
@asynccontextmanager
245+
async def shuffled_session():
246+
async with real_factory() as inner:
247+
original_execute = inner.execute
248+
249+
async def execute_with_shuffle(statement, *args, **kwargs):
250+
result = await original_execute(statement, *args, **kwargs)
251+
if needs_shuffle(statement):
252+
rows = result.all()
253+
shuffled = list(rows)
254+
shuffled.reverse()
255+
return FakeResult(shuffled)
256+
return result
257+
258+
inner.execute = execute_with_shuffle # type: ignore[assignment]
259+
try:
260+
yield inner
261+
finally:
262+
inner.execute = original_execute # type: ignore[assignment]
263+
264+
session._session_factory = shuffled_session # type: ignore[assignment]
265+
try:
266+
retrieved = await session.get_items()
267+
assert [item["id"] for item in retrieved] == [
268+
"older_same_ts",
269+
"rs_same_ts",
270+
"msg_same_ts",
271+
]
272+
273+
latest_two = await session.get_items(limit=2)
274+
assert [item["id"] for item in latest_two] == ["rs_same_ts", "msg_same_ts"]
275+
finally:
276+
session._session_factory = real_factory # type: ignore[assignment]
277+
278+
279+
async def test_pop_item_same_timestamp_returns_latest():
280+
"""Test that pop_item returns the newest item when timestamps tie."""
281+
session_id = "same_timestamp_pop_test"
282+
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
283+
284+
reasoning_item: TResponseInputItem = {
285+
"id": "rs_pop_same_ts",
286+
"type": "reasoning",
287+
"summary": "...",
288+
}
289+
message_item: TResponseInputItem = {
290+
"id": "msg_pop_same_ts",
291+
"type": "message",
292+
"content": [{"type": "output_text", "text": "..."}],
293+
}
294+
await session.add_items([reasoning_item, message_item])
295+
296+
async with session._session_factory() as sess:
297+
await sess.execute(
298+
text(
299+
"UPDATE agent_messages "
300+
"SET created_at = :created_at "
301+
"WHERE session_id = :session_id"
302+
),
303+
{
304+
"created_at": "2025-10-15 17:26:39.132483",
305+
"session_id": session.session_id,
306+
},
307+
)
308+
await sess.commit()
309+
310+
popped = await session.pop_item()
311+
assert popped is not None
312+
assert popped["id"] == "msg_pop_same_ts"
313+
314+
remaining = await session.get_items()
315+
assert [item["id"] for item in remaining] == ["rs_pop_same_ts"]
316+
317+
318+
async def test_get_items_orders_by_id_for_ties():
319+
"""Test that get_items adds id ordering to break timestamp ties."""
320+
session_id = "order_by_id_test"
321+
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
322+
323+
await session.add_items(
324+
[
325+
{"id": "rs_first", "type": "reasoning"},
326+
{"id": "msg_second", "type": "message"},
327+
]
328+
)
329+
330+
real_factory = session._session_factory
331+
recorded: list = []
332+
333+
@asynccontextmanager
334+
async def wrapped_session():
335+
async with real_factory() as inner:
336+
original_execute = inner.execute
337+
338+
async def recording_execute(statement, *args, **kwargs):
339+
recorded.append(statement)
340+
return await original_execute(statement, *args, **kwargs)
341+
342+
inner.execute = recording_execute # type: ignore[assignment]
343+
try:
344+
yield inner
345+
finally:
346+
inner.execute = original_execute # type: ignore[assignment]
347+
348+
session._session_factory = wrapped_session # type: ignore[assignment]
349+
try:
350+
retrieved_full = await session.get_items()
351+
retrieved_limited = await session.get_items(limit=2)
352+
finally:
353+
session._session_factory = real_factory # type: ignore[assignment]
354+
355+
assert len(recorded) >= 2
356+
orderings_full = [str(clause) for clause in recorded[0]._order_by_clause]
357+
assert orderings_full == [
358+
"agent_messages.created_at ASC",
359+
"agent_messages.id ASC",
360+
]
361+
362+
orderings_limited = [str(clause) for clause in recorded[1]._order_by_clause]
363+
assert orderings_limited == [
364+
"agent_messages.created_at DESC",
365+
"agent_messages.id DESC",
366+
]
367+
368+
assert [item["id"] for item in retrieved_full] == ["rs_first", "msg_second"]
369+
assert [item["id"] for item in retrieved_limited] == ["rs_first", "msg_second"]

0 commit comments

Comments
 (0)