Skip to content

Commit 843f2be

Browse files
committed
Fix SQLiteSession threading.Lock() bug and file descriptor leak
This PR addresses two critical bugs in SQLiteSession: ## Bug 1: threading.Lock() creating new instances **Problem:** In SQLiteSession (4 places) and AdvancedSQLiteSession (8 places), the code used: ```python with self._lock if self._is_memory_db else threading.Lock(): ``` For file-based databases, this creates a NEW Lock() instance on every operation, providing NO thread safety whatsoever. Only in-memory databases used self._lock. **Impact:** - File-based SQLiteSession had zero thread protection - Race conditions possible but masked by WAL mode's own concurrency handling ## Bug 2: File descriptor leak **Problem:** Thread-local connections in ThreadPoolExecutor are never cleaned up: - asyncio.to_thread() uses ThreadPoolExecutor internally - Each worker thread creates a connection on first use - ThreadPoolExecutor reuses threads indefinitely - Connections persist until program exit, accumulating file descriptors **Evidence:** Testing on main branch (60s, 40 concurrent workers): - My system (FD limit 1,048,575): +789 FDs leaked, 0 errors (limit not reached) - @ihower's system (likely limit 1,024): 646,802 errors in 20 seconds Error: `sqlite3.OperationalError: unable to open database file` ## Solution: Unified shared connection approach Instead of managing thread-local connections that can't be reliably cleaned up in ThreadPoolExecutor, use a single shared connection for all database types. **Changes:** 1. Removed thread-local connection logic (eliminates FD leak root cause) 2. All database types now use shared connection + self._lock 3. SQLite's WAL mode provides sufficient concurrency even with single connection 4. Fixed all 12 instances of threading.Lock() bug (4 in SQLiteSession, 8 in AdvancedSQLiteSession) 5. Kept _is_memory_db attribute for backward compatibility with AdvancedSQLiteSession 6. Added close() and __del__() methods for proper cleanup
**Results (60s stress test, 30 writers + 10 readers):** ``` Main branch: - FD growth: +789 (leak) - Throughput: 701 ops/s - Errors: 0 on high-limit systems, 646k+ on normal systems After fix: - FD growth: +44 (stable) - Throughput: 726 ops/s (+3.6% improvement) - Errors: 0 on all systems - All 29 SQLite tests pass ``` ## Why shared connection performs better SQLite's WAL (Write-Ahead Logging) mode already provides: - Multiple concurrent readers - One writer coexisting with readers - Readers don't block the writer - The writer doesn't block readers (except during checkpoint) The overhead of managing multiple connections outweighs any concurrency benefit. ## Backward compatibility The _is_memory_db attribute is preserved for AdvancedSQLiteSession compatibility, even though the implementation no longer differentiates connection strategies. ## Testing Comprehensive stress test available at: https://gist.github.com/gn00295120/0b6a65fe6c0ac6b7a1ce23654eed3ffe Run with: `python sqlite_stress_test_final.py`
1 parent 4bc33e3 commit 843f2be

File tree

2 files changed

+97
-110
lines changed

2 files changed

+97
-110
lines changed

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 75 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import asyncio
44
import json
55
import logging
6-
import threading
76
from contextlib import closing
87
from pathlib import Path
98
from typing import Any, Union, cast
@@ -146,7 +145,7 @@ def _get_all_items_sync():
146145
"""Synchronous helper to get all items for a branch."""
147146
conn = self._get_connection()
148147
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
149-
with self._lock if self._is_memory_db else threading.Lock():
148+
with self._lock:
150149
with closing(conn.cursor()) as cursor:
151150
if limit is None:
152151
cursor.execute(
@@ -191,7 +190,7 @@ def _get_items_sync():
191190
"""Synchronous helper to get items for a specific branch."""
192191
conn = self._get_connection()
193192
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
194-
with self._lock if self._is_memory_db else threading.Lock():
193+
with self._lock:
195194
with closing(conn.cursor()) as cursor:
196195
# Get message IDs in correct order for this branch
197196
if limit is None:
@@ -261,18 +260,19 @@ def _get_next_turn_number(self, branch_id: str) -> int:
261260
The next available turn number for the specified branch.
262261
"""
263262
conn = self._get_connection()
264-
with closing(conn.cursor()) as cursor:
265-
cursor.execute(
266-
"""
267-
SELECT COALESCE(MAX(user_turn_number), 0)
268-
FROM message_structure
269-
WHERE session_id = ? AND branch_id = ?
270-
""",
271-
(self.session_id, branch_id),
272-
)
273-
result = cursor.fetchone()
274-
max_turn = result[0] if result else 0
275-
return max_turn + 1
263+
with self._lock:
264+
with closing(conn.cursor()) as cursor:
265+
cursor.execute(
266+
"""
267+
SELECT COALESCE(MAX(user_turn_number), 0)
268+
FROM message_structure
269+
WHERE session_id = ? AND branch_id = ?
270+
""",
271+
(self.session_id, branch_id),
272+
)
273+
result = cursor.fetchone()
274+
max_turn = result[0] if result else 0
275+
return max_turn + 1
276276

277277
def _get_next_branch_turn_number(self, branch_id: str) -> int:
278278
"""Get the next branch turn number for a specific branch.
@@ -284,18 +284,19 @@ def _get_next_branch_turn_number(self, branch_id: str) -> int:
284284
The next available branch turn number for the specified branch.
285285
"""
286286
conn = self._get_connection()
287-
with closing(conn.cursor()) as cursor:
288-
cursor.execute(
289-
"""
290-
SELECT COALESCE(MAX(branch_turn_number), 0)
291-
FROM message_structure
292-
WHERE session_id = ? AND branch_id = ?
293-
""",
294-
(self.session_id, branch_id),
295-
)
296-
result = cursor.fetchone()
297-
max_turn = result[0] if result else 0
298-
return max_turn + 1
287+
with self._lock:
288+
with closing(conn.cursor()) as cursor:
289+
cursor.execute(
290+
"""
291+
SELECT COALESCE(MAX(branch_turn_number), 0)
292+
FROM message_structure
293+
WHERE session_id = ? AND branch_id = ?
294+
""",
295+
(self.session_id, branch_id),
296+
)
297+
result = cursor.fetchone()
298+
max_turn = result[0] if result else 0
299+
return max_turn + 1
299300

300301
def _get_current_turn_number(self) -> int:
301302
"""Get the current turn number for the current branch.
@@ -304,17 +305,18 @@ def _get_current_turn_number(self) -> int:
304305
The current turn number for the active branch.
305306
"""
306307
conn = self._get_connection()
307-
with closing(conn.cursor()) as cursor:
308-
cursor.execute(
309-
"""
310-
SELECT COALESCE(MAX(user_turn_number), 0)
311-
FROM message_structure
312-
WHERE session_id = ? AND branch_id = ?
313-
""",
314-
(self.session_id, self._current_branch_id),
315-
)
316-
result = cursor.fetchone()
317-
return result[0] if result else 0
308+
with self._lock:
309+
with closing(conn.cursor()) as cursor:
310+
cursor.execute(
311+
"""
312+
SELECT COALESCE(MAX(user_turn_number), 0)
313+
FROM message_structure
314+
WHERE session_id = ? AND branch_id = ?
315+
""",
316+
(self.session_id, self._current_branch_id),
317+
)
318+
result = cursor.fetchone()
319+
return result[0] if result else 0
318320

319321
async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None:
320322
"""Extract structure metadata with branch-aware turn tracking.
@@ -333,7 +335,7 @@ def _add_structure_sync():
333335
"""Synchronous helper to add structure metadata to database."""
334336
conn = self._get_connection()
335337
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
336-
with self._lock if self._is_memory_db else threading.Lock():
338+
with self._lock:
337339
# Get the IDs of messages we just inserted, in order
338340
with closing(conn.cursor()) as cursor:
339341
cursor.execute(
@@ -439,7 +441,7 @@ def _cleanup_sync():
439441
"""Synchronous helper to cleanup orphaned messages."""
440442
conn = self._get_connection()
441443
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
442-
with self._lock if self._is_memory_db else threading.Lock():
444+
with self._lock:
443445
with closing(conn.cursor()) as cursor:
444446
# Find messages without structure metadata
445447
cursor.execute(
@@ -694,7 +696,7 @@ def _delete_sync():
694696
"""Synchronous helper to delete branch and associated data."""
695697
conn = self._get_connection()
696698
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
697-
with self._lock if self._is_memory_db else threading.Lock():
699+
with self._lock:
698700
with closing(conn.cursor()) as cursor:
699701
# First verify the branch exists
700702
cursor.execute(
@@ -756,36 +758,37 @@ async def list_branches(self) -> list[dict[str, Any]]:
756758
def _list_branches_sync():
757759
"""Synchronous helper to list all branches."""
758760
conn = self._get_connection()
759-
with closing(conn.cursor()) as cursor:
760-
cursor.execute(
761-
"""
762-
SELECT
763-
ms.branch_id,
764-
COUNT(*) as message_count,
765-
COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
766-
MIN(ms.created_at) as created_at
767-
FROM message_structure ms
768-
WHERE ms.session_id = ?
769-
GROUP BY ms.branch_id
770-
ORDER BY created_at
771-
""",
772-
(self.session_id,),
773-
)
774-
775-
branches = []
776-
for row in cursor.fetchall():
777-
branch_id, msg_count, user_turns, created_at = row
778-
branches.append(
779-
{
780-
"branch_id": branch_id,
781-
"message_count": msg_count,
782-
"user_turns": user_turns,
783-
"is_current": branch_id == self._current_branch_id,
784-
"created_at": created_at,
785-
}
761+
with self._lock:
762+
with closing(conn.cursor()) as cursor:
763+
cursor.execute(
764+
"""
765+
SELECT
766+
ms.branch_id,
767+
COUNT(*) as message_count,
768+
COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
769+
MIN(ms.created_at) as created_at
770+
FROM message_structure ms
771+
WHERE ms.session_id = ?
772+
GROUP BY ms.branch_id
773+
ORDER BY created_at
774+
""",
775+
(self.session_id,),
786776
)
787777

788-
return branches
778+
branches = []
779+
for row in cursor.fetchall():
780+
branch_id, msg_count, user_turns, created_at = row
781+
branches.append(
782+
{
783+
"branch_id": branch_id,
784+
"message_count": msg_count,
785+
"user_turns": user_turns,
786+
"is_current": branch_id == self._current_branch_id,
787+
"created_at": created_at,
788+
}
789+
)
790+
791+
return branches
789792

790793
return await asyncio.to_thread(_list_branches_sync)
791794

@@ -801,7 +804,7 @@ def _copy_sync():
801804
"""Synchronous helper to copy messages to new branch."""
802805
conn = self._get_connection()
803806
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
804-
with self._lock if self._is_memory_db else threading.Lock():
807+
with self._lock:
805808
with closing(conn.cursor()) as cursor:
806809
# Get all messages before the branch point
807810
cursor.execute(
@@ -1072,7 +1075,7 @@ def _get_usage_sync():
10721075
"""Synchronous helper to get session usage data."""
10731076
conn = self._get_connection()
10741077
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
1075-
with self._lock if self._is_memory_db else threading.Lock():
1078+
with self._lock:
10761079
if branch_id:
10771080
# Branch-specific usage
10781081
query = """
@@ -1236,7 +1239,7 @@ def _update_sync():
12361239
"""Synchronous helper to update turn usage data."""
12371240
conn = self._get_connection()
12381241
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
1239-
with self._lock if self._is_memory_db else threading.Lock():
1242+
with self._lock:
12401243
# Serialize token details as JSON
12411244
input_details_json = None
12421245
output_details_json = None

src/agents/memory/sqlite_session.py

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -38,40 +38,21 @@ def __init__(
3838
self.db_path = db_path
3939
self.sessions_table = sessions_table
4040
self.messages_table = messages_table
41-
self._local = threading.local()
4241
self._lock = threading.Lock()
4342

44-
# For in-memory databases, we need a shared connection to avoid thread isolation
45-
# For file databases, we use thread-local connections for better concurrency
43+
# Keep _is_memory_db for backward compatibility with AdvancedSQLiteSession
4644
self._is_memory_db = str(db_path) == ":memory:"
47-
if self._is_memory_db:
48-
self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
49-
self._shared_connection.execute("PRAGMA journal_mode=WAL")
50-
self._init_db_for_connection(self._shared_connection)
51-
else:
52-
# For file databases, initialize the schema once since it persists
53-
init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
54-
init_conn.execute("PRAGMA journal_mode=WAL")
55-
self._init_db_for_connection(init_conn)
56-
init_conn.close()
45+
46+
# Use a shared connection for all database types
47+
# This avoids file descriptor leaks from thread-local connections
48+
# WAL mode enables concurrent readers/writers even with a shared connection
49+
self._shared_connection = sqlite3.connect(str(db_path), check_same_thread=False)
50+
self._shared_connection.execute("PRAGMA journal_mode=WAL")
51+
self._init_db_for_connection(self._shared_connection)
5752

5853
def _get_connection(self) -> sqlite3.Connection:
5954
"""Get a database connection."""
60-
if self._is_memory_db:
61-
# Use shared connection for in-memory database to avoid thread isolation
62-
return self._shared_connection
63-
else:
64-
# Use thread-local connections for file databases
65-
if not hasattr(self._local, "connection"):
66-
self._local.connection = sqlite3.connect(
67-
str(self.db_path),
68-
check_same_thread=False,
69-
)
70-
self._local.connection.execute("PRAGMA journal_mode=WAL")
71-
assert isinstance(self._local.connection, sqlite3.Connection), (
72-
f"Expected sqlite3.Connection, got {type(self._local.connection)}"
73-
)
74-
return self._local.connection
55+
return self._shared_connection
7556

7657
def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
7758
"""Initialize the database schema for a specific connection."""
@@ -120,7 +101,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
120101

121102
def _get_items_sync():
122103
conn = self._get_connection()
123-
with self._lock if self._is_memory_db else threading.Lock():
104+
with self._lock:
124105
if limit is None:
125106
# Fetch all items in chronological order
126107
cursor = conn.execute(
@@ -174,7 +155,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
174155
def _add_items_sync():
175156
conn = self._get_connection()
176157

177-
with self._lock if self._is_memory_db else threading.Lock():
158+
with self._lock:
178159
# Ensure session exists
179160
conn.execute(
180161
f"""
@@ -215,7 +196,7 @@ async def pop_item(self) -> TResponseInputItem | None:
215196

216197
def _pop_item_sync():
217198
conn = self._get_connection()
218-
with self._lock if self._is_memory_db else threading.Lock():
199+
with self._lock:
219200
# Use DELETE with RETURNING to atomically delete and return the most recent item
220201
cursor = conn.execute(
221202
f"""
@@ -252,7 +233,7 @@ async def clear_session(self) -> None:
252233

253234
def _clear_session_sync():
254235
conn = self._get_connection()
255-
with self._lock if self._is_memory_db else threading.Lock():
236+
with self._lock:
256237
conn.execute(
257238
f"DELETE FROM {self.messages_table} WHERE session_id = ?",
258239
(self.session_id,),
@@ -267,9 +248,12 @@ def _clear_session_sync():
267248

268249
def close(self) -> None:
269250
"""Close the database connection."""
270-
if self._is_memory_db:
271-
if hasattr(self, "_shared_connection"):
272-
self._shared_connection.close()
273-
else:
274-
if hasattr(self._local, "connection"):
275-
self._local.connection.close()
251+
if hasattr(self, "_shared_connection"):
252+
self._shared_connection.close()
253+
254+
def __del__(self) -> None:
255+
"""Ensure connection is closed when the session is garbage collected."""
256+
try:
257+
self.close()
258+
except Exception:
259+
pass # Ignore errors during finalization

0 commit comments

Comments (0)