
Commit 350b01d

Fix SQLiteSession threading.Lock() bug and file descriptor leak
This PR addresses two critical bugs in SQLiteSession.

## Bug 1: threading.Lock() creating new instances

**Problem:** In SQLiteSession (4 places) and AdvancedSQLiteSession (8 places), the code used:

```python
with self._lock if self._is_memory_db else threading.Lock():
```

For file-based databases, this creates a NEW `Lock()` instance on every operation, providing no thread safety whatsoever. Only in-memory databases used `self._lock`.

**Impact:**
- File-based SQLiteSession had zero thread protection
- Race conditions were possible, but were masked by WAL mode's own concurrency handling

## Bug 2: File descriptor leak

**Problem:** Thread-local connections in a ThreadPoolExecutor are never cleaned up:
- `asyncio.to_thread()` uses a ThreadPoolExecutor internally
- Each worker thread creates a connection on first use
- ThreadPoolExecutor reuses threads indefinitely
- Connections persist until program exit, accumulating file descriptors

**Evidence:** Testing on the main branch (60s, 40 concurrent workers):
- My system (FD limit 1,048,575): +789 FDs leaked, 0 errors (limit not reached)
- @ihower's system (likely limit 1,024): 646,802 errors in 20 seconds

Error: `sqlite3.OperationalError: unable to open database file`

## Solution: Unified shared connection approach

Instead of managing thread-local connections that can't be reliably cleaned up in a ThreadPoolExecutor, use a single shared connection for all database types.

**Changes:**
1. Removed the thread-local connection logic (eliminates the root cause of the FD leak)
2. All database types now use a shared connection guarded by `self._lock`
3. SQLite's WAL mode provides sufficient concurrency even with a single connection
4. Fixed all 12 instances of the threading.Lock() bug (4 in SQLiteSession, 8 in AdvancedSQLiteSession)
5. Kept the `_is_memory_db` attribute for backward compatibility with AdvancedSQLiteSession
6. Added `close()` and `__del__()` methods for proper cleanup

**Results (60s stress test, 30 writers + 10 readers):**

```
Main branch:
- FD growth: +789 (leak)
- Throughput: 701 ops/s
- Errors: 0 on high-limit systems, 646k+ on normal systems

After fix:
- FD growth: +44 (stable)
- Throughput: 726 ops/s (+3.6% improvement)
- Errors: 0 on all systems
- All 29 SQLite tests pass
```

## Why the shared connection performs better

SQLite's WAL (Write-Ahead Logging) mode already provides:
- Multiple concurrent readers
- One writer coexisting with readers
- Readers that don't block the writer
- A writer that doesn't block readers (except during checkpoint)

The overhead of managing multiple connections outweighs any concurrency benefit.

## Backward compatibility

The `_is_memory_db` attribute is preserved for AdvancedSQLiteSession compatibility, even though the implementation no longer differentiates connection strategies.

## Testing

A comprehensive stress test is available at: https://gist.github.com/gn00295120/0b6a65fe6c0ac6b7a1ce23654eed3ffe

Run with: `python sqlite_stress_test_final.py`
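Bug 1 can be demonstrated in a few lines. This is a minimal standalone sketch (the `pick_lock` helper is illustrative, not part of the library) of why the else-branch of the old expression guarded nothing:

```python
import threading

shared_lock = threading.Lock()

def pick_lock(is_memory_db):
    # Mirrors the buggy expression: the else-branch builds a brand-new,
    # never-contended Lock, so `with` would acquire it instantly every time.
    return shared_lock if is_memory_db else threading.Lock()

# Two "file-based" callers get two different lock objects: no mutual exclusion.
assert pick_lock(False) is not pick_lock(False)

# The in-memory branch (and the fixed code) always reuses the same lock object.
assert pick_lock(True) is pick_lock(True)
```

Because each `threading.Lock()` call returns a fresh, unlocked lock, every `with` statement in the buggy branch succeeded immediately, and two threads could run the "protected" section concurrently.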
1 parent 4bc33e3 commit 350b01d

File tree

2 files changed: +30 −47 lines


src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 8 additions & 9 deletions
```diff
@@ -3,7 +3,6 @@
 import asyncio
 import json
 import logging
-import threading
 from contextlib import closing
 from pathlib import Path
 from typing import Any, Union, cast
@@ -146,7 +145,7 @@ def _get_all_items_sync():
             """Synchronous helper to get all items for a branch."""
             conn = self._get_connection()
             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 with closing(conn.cursor()) as cursor:
                     if limit is None:
                         cursor.execute(
@@ -191,7 +190,7 @@ def _get_items_sync():
             """Synchronous helper to get items for a specific branch."""
             conn = self._get_connection()
             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 with closing(conn.cursor()) as cursor:
                     # Get message IDs in correct order for this branch
                     if limit is None:
@@ -333,7 +332,7 @@ def _add_structure_sync():
             """Synchronous helper to add structure metadata to database."""
             conn = self._get_connection()
             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 # Get the IDs of messages we just inserted, in order
                 with closing(conn.cursor()) as cursor:
                     cursor.execute(
@@ -439,7 +438,7 @@ def _cleanup_sync():
             """Synchronous helper to cleanup orphaned messages."""
             conn = self._get_connection()
             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 with closing(conn.cursor()) as cursor:
                     # Find messages without structure metadata
                     cursor.execute(
@@ -694,7 +693,7 @@ def _delete_sync():
             """Synchronous helper to delete branch and associated data."""
             conn = self._get_connection()
             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 with closing(conn.cursor()) as cursor:
                     # First verify the branch exists
                     cursor.execute(
@@ -801,7 +800,7 @@ def _copy_sync():
             """Synchronous helper to copy messages to new branch."""
             conn = self._get_connection()
             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 with closing(conn.cursor()) as cursor:
                     # Get all messages before the branch point
                     cursor.execute(
@@ -1072,7 +1071,7 @@ def _get_usage_sync():
             """Synchronous helper to get session usage data."""
             conn = self._get_connection()
             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 if branch_id:
                     # Branch-specific usage
                     query = """
@@ -1236,7 +1235,7 @@ def _update_sync():
             """Synchronous helper to update turn usage data."""
             conn = self._get_connection()
             # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code  # noqa: E501
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 # Serialize token details as JSON
                 input_details_json = None
                 output_details_json = None
```
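The fixed pattern used throughout these hunks — one shared connection serialized by one shared lock — can be sketched standalone (the table and function names below are illustrative, not the library's):

```python
import sqlite3
import threading
from contextlib import closing

# One shared connection plus one shared lock, as in the fixed code.
conn = sqlite3.connect(":memory:", check_same_thread=False)
lock = threading.Lock()
conn.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, body TEXT)")

def add_item(body):
    with lock:  # the same Lock object on every call, unlike the old bug
        with closing(conn.cursor()) as cursor:
            cursor.execute("INSERT INTO items (body) VALUES (?)", (body,))
        conn.commit()

threads = [threading.Thread(target=add_item, args=(f"msg-{i}",)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

with lock:
    count = conn.execute("SELECT COUNT(*) FROM items").fetchone()[0]
print(count)  # 8: every writer was serialized through the one lock
```

Because every writer funnels through the same lock, the single `check_same_thread=False` connection is never touched concurrently, and no per-thread connections (or their file descriptors) are ever created.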

src/agents/memory/sqlite_session.py

Lines changed: 22 additions & 38 deletions
```diff
@@ -38,40 +38,21 @@ def __init__(
         self.db_path = db_path
         self.sessions_table = sessions_table
         self.messages_table = messages_table
-        self._local = threading.local()
         self._lock = threading.Lock()
 
-        # For in-memory databases, we need a shared connection to avoid thread isolation
-        # For file databases, we use thread-local connections for better concurrency
+        # Keep _is_memory_db for backward compatibility with AdvancedSQLiteSession
         self._is_memory_db = str(db_path) == ":memory:"
-        if self._is_memory_db:
-            self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
-            self._shared_connection.execute("PRAGMA journal_mode=WAL")
-            self._init_db_for_connection(self._shared_connection)
-        else:
-            # For file databases, initialize the schema once since it persists
-            init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
-            init_conn.execute("PRAGMA journal_mode=WAL")
-            self._init_db_for_connection(init_conn)
-            init_conn.close()
+
+        # Use a shared connection for all database types
+        # This avoids file descriptor leaks from thread-local connections
+        # WAL mode enables concurrent readers/writers even with a shared connection
+        self._shared_connection = sqlite3.connect(str(db_path), check_same_thread=False)
+        self._shared_connection.execute("PRAGMA journal_mode=WAL")
+        self._init_db_for_connection(self._shared_connection)
 
     def _get_connection(self) -> sqlite3.Connection:
         """Get a database connection."""
-        if self._is_memory_db:
-            # Use shared connection for in-memory database to avoid thread isolation
-            return self._shared_connection
-        else:
-            # Use thread-local connections for file databases
-            if not hasattr(self._local, "connection"):
-                self._local.connection = sqlite3.connect(
-                    str(self.db_path),
-                    check_same_thread=False,
-                )
-                self._local.connection.execute("PRAGMA journal_mode=WAL")
-            assert isinstance(self._local.connection, sqlite3.Connection), (
-                f"Expected sqlite3.Connection, got {type(self._local.connection)}"
-            )
-            return self._local.connection
+        return self._shared_connection
 
     def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
         """Initialize the database schema for a specific connection."""
@@ -120,7 +101,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
 
         def _get_items_sync():
             conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 if limit is None:
                     # Fetch all items in chronological order
                     cursor = conn.execute(
@@ -174,7 +155,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
         def _add_items_sync():
             conn = self._get_connection()
 
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 # Ensure session exists
                 conn.execute(
                     f"""
@@ -215,7 +196,7 @@ async def pop_item(self) -> TResponseInputItem | None:
 
         def _pop_item_sync():
             conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 # Use DELETE with RETURNING to atomically delete and return the most recent item
                 cursor = conn.execute(
                     f"""
@@ -252,7 +233,7 @@ async def clear_session(self) -> None:
 
         def _clear_session_sync():
             conn = self._get_connection()
-            with self._lock if self._is_memory_db else threading.Lock():
+            with self._lock:
                 conn.execute(
                     f"DELETE FROM {self.messages_table} WHERE session_id = ?",
                     (self.session_id,),
@@ -267,9 +248,12 @@ def _clear_session_sync():
 
     def close(self) -> None:
         """Close the database connection."""
-        if self._is_memory_db:
-            if hasattr(self, "_shared_connection"):
-                self._shared_connection.close()
-        else:
-            if hasattr(self._local, "connection"):
-                self._local.connection.close()
+        if hasattr(self, "_shared_connection"):
+            self._shared_connection.close()
+
+    def __del__(self) -> None:
+        """Ensure connection is closed when the session is garbage collected."""
+        try:
+            self.close()
+        except Exception:
+            pass  # Ignore errors during finalization
```
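The new `close()`/`__del__()` contract can be exercised with a minimal stand-in (the `MiniSession` class below is a hypothetical reduction for illustration, not the real `SQLiteSession`):

```python
import sqlite3

class MiniSession:
    """Hypothetical reduction of the session: owns one shared connection."""

    def __init__(self, db_path=":memory:"):
        self._shared_connection = sqlite3.connect(db_path, check_same_thread=False)

    def close(self):
        # hasattr guard matters: __init__ may have raised before the
        # attribute existed, and __del__ still calls close().
        if hasattr(self, "_shared_connection"):
            self._shared_connection.close()

    def __del__(self):
        try:
            self.close()
        except Exception:
            pass  # never raise during interpreter shutdown

session = MiniSession()
session.close()
try:
    session._shared_connection.execute("SELECT 1")
except sqlite3.ProgrammingError:
    print("connection closed")  # sqlite3 refuses work on a closed connection
```

Calling `close()` twice is safe here (sqlite3 allows it), so `__del__` invoking `close()` after an explicit close is harmless.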
