|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from contextlib import asynccontextmanager |
| 4 | +from datetime import datetime, timedelta |
| 5 | +import json |
| 6 | + |
3 | 7 | import pytest |
| 8 | +from sqlalchemy import select, text, update |
| 9 | +from sqlalchemy.sql import Select |
4 | 10 |
|
5 | 11 | pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed |
6 | 12 |
|
@@ -151,3 +157,213 @@ async def test_add_empty_items_list(): |
151 | 157 |
|
152 | 158 | items_after_add = await session.get_items() |
153 | 159 | assert len(items_after_add) == 0 |
| 160 | + |
| 161 | + |
async def test_get_items_same_timestamp_consistent_order():
    """Check that rows sharing one ``created_at`` value come back in insertion order.

    Three items are stored, then their timestamps are rewritten so that the
    last two tie exactly and the first is one millisecond earlier. The session
    factory is then wrapped so that any SELECT against the messages table that
    does not order by the ``id`` column has its rows reversed, which makes a
    missing tie-breaker show up as a wrong item order.
    """
    session_id = "same_timestamp_test"
    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)

    older_item: TResponseInputItem = {
        "id": "older_same_ts",
        "type": "message",
        "content": [{"type": "output_text", "text": "old"}],
    }
    reasoning_item: TResponseInputItem = {
        "id": "rs_same_ts",
        "type": "reasoning",
        "summary": "...",
    }
    message_item: TResponseInputItem = {
        "id": "msg_same_ts",
        "type": "message",
        "content": [{"type": "output_text", "text": "..."}],
    }
    await session.add_items([older_item])
    await session.add_items([reasoning_item, message_item])

    messages = session._messages

    async with session._session_factory() as db:
        stored = await db.execute(
            select(messages.c.id, messages.c.message_data).where(
                messages.c.session_id == session.session_id
            )
        )
        # Map each item's JSON "id" to the primary key of the row holding it.
        row_id_by_item_id = {}
        for row_id, payload in stored.fetchall():
            row_id_by_item_id[json.loads(payload)["id"]] = row_id

        tied_at = datetime(2025, 10, 15, 17, 26, 39, 132483)
        earlier_at = tied_at - timedelta(milliseconds=1)

        # Force the last two rows onto exactly the same timestamp.
        tied_row_ids = [
            row_id_by_item_id["rs_same_ts"],
            row_id_by_item_id["msg_same_ts"],
        ]
        await db.execute(
            update(messages).where(messages.c.id.in_(tied_row_ids)).values(created_at=tied_at)
        )
        # Keep the first row strictly older than the tied pair.
        await db.execute(
            update(messages)
            .where(messages.c.id == row_id_by_item_id["older_same_ts"])
            .values(created_at=earlier_at)
        )
        await db.commit()

    original_factory = session._session_factory

    class ReversedRows:
        """Minimal result stand-in exposing only ``all()``."""

        def __init__(self, rows):
            self._rows = rows

        def all(self):
            return list(self._rows)

    def should_reverse(statement: Select) -> bool:
        """Return True for ordered messages-table SELECTs that do not order by id."""
        if not isinstance(statement, Select):
            return False
        order_clauses = list(statement._order_by_clause)
        if not order_clauses:
            return False
        id_orderings = (messages.c.id.asc(), messages.c.id.desc())
        for clause in order_clauses:
            try:
                if any(clause.compare(candidate) for candidate in id_orderings):
                    # Already tie-broken on id: leave the rows alone.
                    return False
            except AttributeError:
                # Clause cannot be compared; treat it as not referencing id.
                continue
        # Only interfere with queries that target the messages table.
        from_names = {source.name for source in statement.get_final_froms()}
        return messages.name in from_names

    @asynccontextmanager
    async def reversing_factory():
        async with original_factory() as db:
            real_execute = db.execute

            async def execute_and_maybe_reverse(statement, *args, **kwargs):
                outcome = await real_execute(statement, *args, **kwargs)
                if not should_reverse(statement):
                    return outcome
                return ReversedRows(list(reversed(outcome.all())))

            db.execute = execute_and_maybe_reverse  # type: ignore[assignment]
            try:
                yield db
            finally:
                db.execute = real_execute  # type: ignore[assignment]

    session._session_factory = reversing_factory  # type: ignore[assignment]
    try:
        everything = await session.get_items()
        assert [item["id"] for item in everything] == [
            "older_same_ts",
            "rs_same_ts",
            "msg_same_ts",
        ]

        newest_pair = await session.get_items(limit=2)
        assert [item["id"] for item in newest_pair] == ["rs_same_ts", "msg_same_ts"]
    finally:
        # Always put the real factory back, even when an assertion fails.
        session._session_factory = original_factory  # type: ignore[assignment]
| 277 | + |
| 278 | + |
async def test_pop_item_same_timestamp_returns_latest():
    """Check that ``pop_item`` removes the last-inserted item when timestamps tie.

    Two items are added in one call, every row of the session is then given
    the same ``created_at`` value, and the pop is expected to return the
    second item while leaving the first in place.
    """
    session_id = "same_timestamp_pop_test"
    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)

    reasoning_item: TResponseInputItem = {
        "id": "rs_pop_same_ts",
        "type": "reasoning",
        "summary": "...",
    }
    message_item: TResponseInputItem = {
        "id": "msg_pop_same_ts",
        "type": "message",
        "content": [{"type": "output_text", "text": "..."}],
    }
    await session.add_items([reasoning_item, message_item])

    # Overwrite created_at for every row of this session with one shared value.
    force_same_timestamp = text(
        "UPDATE agent_messages "
        "SET created_at = :created_at "
        "WHERE session_id = :session_id"
    )
    bind_values = {
        "created_at": "2025-10-15 17:26:39.132483",
        "session_id": session.session_id,
    }
    async with session._session_factory() as db:
        await db.execute(force_same_timestamp, bind_values)
        await db.commit()

    popped = await session.pop_item()
    assert popped is not None
    assert popped["id"] == "msg_pop_same_ts"

    leftover = await session.get_items()
    assert [item["id"] for item in leftover] == ["rs_pop_same_ts"]
| 316 | + |
| 317 | + |
async def test_get_items_orders_by_id_for_ties():
    """Check that the statements issued by ``get_items`` order by timestamp, then id.

    The session factory is wrapped so that every statement passed to
    ``execute`` is recorded. The first recorded statement (unlimited call) must
    order ascending and the second (``limit=2`` call) descending, each with the
    ``id`` column as the secondary key. Both calls are still expected to return
    the items in insertion order.
    """
    session_id = "order_by_id_test"
    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)

    first_item = {"id": "rs_first", "type": "reasoning"}
    second_item = {"id": "msg_second", "type": "message"}
    await session.add_items([first_item, second_item])

    original_factory = session._session_factory
    captured: list = []

    @asynccontextmanager
    async def recording_factory():
        async with original_factory() as db:
            real_execute = db.execute

            async def execute_and_record(statement, *args, **kwargs):
                # Remember the statement, then run it unchanged.
                captured.append(statement)
                return await real_execute(statement, *args, **kwargs)

            db.execute = execute_and_record  # type: ignore[assignment]
            try:
                yield db
            finally:
                db.execute = real_execute  # type: ignore[assignment]

    session._session_factory = recording_factory  # type: ignore[assignment]
    try:
        all_items = await session.get_items()
        limited_items = await session.get_items(limit=2)
    finally:
        # Restore the real factory regardless of how the calls went.
        session._session_factory = original_factory  # type: ignore[assignment]

    def ordering_sql(statement) -> list[str]:
        """Render each ORDER BY clause of ``statement`` as a string."""
        return [str(clause) for clause in statement._order_by_clause]

    assert len(captured) >= 2

    full_ordering = ordering_sql(captured[0])
    assert full_ordering == [
        "agent_messages.created_at ASC",
        "agent_messages.id ASC",
    ]

    limited_ordering = ordering_sql(captured[1])
    assert limited_ordering == [
        "agent_messages.created_at DESC",
        "agent_messages.id DESC",
    ]

    expected_ids = ["rs_first", "msg_second"]
    assert [item["id"] for item in all_items] == expected_ids
    assert [item["id"] for item in limited_items] == expected_ids
0 commit comments