Skip to content

Commit f669093

Browse files
committed
temporary fixes to pass typecheck and old_versions CI tests
1 parent e534fa0 commit f669093

File tree

3 files changed

+29
-39
lines changed

3 files changed

+29
-39
lines changed

src/agents/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
RunErrorDetails,
2424
UserError,
2525
)
26-
from .extensions.memory import SQLAlchemySession
2726
from .guardrail import (
2827
GuardrailFunctionOutput,
2928
InputGuardrail,
@@ -223,7 +222,6 @@ def enable_verbose_stdout_logging():
223222
"AgentHooks",
224223
"Session",
225224
"SQLiteSession",
226-
"SQLAlchemySession",
227225
"RunContextWrapper",
228226
"TContext",
229227
"RunErrorDetails",

src/agents/extensions/memory/sqlalchemy_session.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import json
2626
from typing import Any
2727

28-
from sqlalchemy import ( # type: ignore
28+
from sqlalchemy import (
2929
TIMESTAMP,
3030
Column,
3131
ForeignKey,
@@ -40,10 +40,10 @@
4040
text as sql_text,
4141
update,
4242
)
43-
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine # type: ignore
44-
from sqlalchemy.orm import sessionmaker # type: ignore
43+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
44+
from sqlalchemy.ext.asyncio import async_sessionmaker
4545

46-
from ...items import TResponseInputItem # type: ignore
46+
from ...items import TResponseInputItem
4747
from ...memory.session import SessionABC
4848

4949

@@ -124,7 +124,7 @@ def __init__(
124124
)
125125

126126
# Index for efficient retrieval of messages per session ordered by time
127-
from sqlalchemy import Index # type: ignore
127+
from sqlalchemy import Index
128128

129129
Index(
130130
f"idx_{messages_table}_session_time",
@@ -133,8 +133,8 @@ def __init__(
133133
)
134134

135135
# Async session factory
136-
self._session_factory: sessionmaker[AsyncSession] = sessionmaker(
137-
self._engine, expire_on_commit=False, class_=AsyncSession
136+
self._session_factory = async_sessionmaker(
137+
self._engine, expire_on_commit=False
138138
)
139139

140140
self._create_tables = create_tables
@@ -252,40 +252,28 @@ async def pop_item(self) -> TResponseInputItem | None:
252252
await self._ensure_tables()
253253
async with self._session_factory() as sess:
254254
async with sess.begin():
255-
# First try dialects that support DELETE … RETURNING for atomicity
256-
try:
257-
stmt = (
258-
delete(self._messages)
259-
.where(self._messages.c.session_id == self.session_id)
260-
.order_by(self._messages.c.created_at.desc())
261-
.limit(1)
262-
.returning(self._messages.c.message_data)
263-
)
264-
result = await sess.execute(stmt)
265-
row = result.scalar_one_or_none()
266-
except Exception: # pragma: no cover – fallback path
267-
# Fallback for dialects that don't support ORDER BY in DELETE
268-
subq = (
269-
select(self._messages.c.id)
270-
.where(self._messages.c.session_id == self.session_id)
271-
.order_by(self._messages.c.created_at.desc())
272-
.limit(1)
273-
)
274-
res = await sess.execute(subq)
275-
row_id = res.scalar_one_or_none()
276-
if row_id is None:
277-
return None
278-
# Fetch data before deleting
279-
res_data = await sess.execute(
280-
select(self._messages.c.message_data).where(self._messages.c.id == row_id)
281-
)
282-
row = res_data.scalar_one_or_none()
283-
await sess.execute(delete(self._messages).where(self._messages.c.id == row_id))
255+
# Fallback for all dialects - get ID first, then delete
256+
subq = (
257+
select(self._messages.c.id)
258+
.where(self._messages.c.session_id == self.session_id)
259+
.order_by(self._messages.c.created_at.desc())
260+
.limit(1)
261+
)
262+
res = await sess.execute(subq)
263+
row_id = res.scalar_one_or_none()
264+
if row_id is None:
265+
return None
266+
# Fetch data before deleting
267+
res_data = await sess.execute(
268+
select(self._messages.c.message_data).where(self._messages.c.id == row_id)
269+
)
270+
row = res_data.scalar_one_or_none()
271+
await sess.execute(delete(self._messages).where(self._messages.c.id == row_id))
284272

285273
if row is None:
286274
return None
287275
try:
288-
return json.loads(row)
276+
return json.loads(row) # type: ignore[no-any-return]
289277
except json.JSONDecodeError:
290278
return None
291279

tests/extensions/memory/test_sqlalchemy_session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import pytest
44

5+
pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed
6+
57
from agents import Agent, Runner, TResponseInputItem
68
from agents.extensions.memory.sqlalchemy_session import SQLAlchemySession
79
from tests.fake_model import FakeModel
@@ -58,6 +60,7 @@ async def test_runner_integration(agent: Agent):
5860
session = SQLAlchemySession.from_url(session_id, url=DB_URL)
5961

6062
# First turn
63+
assert isinstance(agent.model, FakeModel)
6164
agent.model.set_next_output([get_text_message("San Francisco")])
6265
result1 = await Runner.run(
6366
agent,
@@ -86,6 +89,7 @@ async def test_session_isolation(agent: Agent):
8689
session2 = SQLAlchemySession.from_url(session_id_2, url=DB_URL)
8790

8891
# Interact with session 1
92+
assert isinstance(agent.model, FakeModel)
8993
agent.model.set_next_output([get_text_message("I like cats.")])
9094
await Runner.run(agent, "I like cats.", session=session1)
9195

0 commit comments

Comments
 (0)