Skip to content

Commit a4d77fe

Browse files
committed
Fix SQLiteSession threading.Lock() bug and file descriptor leak
This PR addresses two critical bugs in SQLiteSession: ## Bug 1: threading.Lock() creating new instances **Problem:** In SQLiteSession (4 places) and AdvancedSQLiteSession (8 places), the code used: ```python with self._lock if self._is_memory_db else threading.Lock(): ``` For file-based databases, this creates a NEW Lock() instance on every operation, providing NO thread safety whatsoever. Only in-memory databases used self._lock. **Impact:** - File-based SQLiteSession had zero thread protection - Race conditions possible but masked by WAL mode's own concurrency handling ## Bug 2: File descriptor leak **Problem:** Thread-local connections in ThreadPoolExecutor are never cleaned up: - asyncio.to_thread() uses ThreadPoolExecutor internally - Each worker thread creates a connection on first use - ThreadPoolExecutor reuses threads indefinitely - Connections persist until program exit, accumulating file descriptors **Evidence:** Testing on main branch (60s, 40 concurrent workers): - My system (FD limit 1,048,575): +789 FDs leaked, 0 errors (limit not reached) - @ihower's system (likely limit 1,024): 646,802 errors in 20 seconds Error: `sqlite3.OperationalError: unable to open database file` ## Solution: Unified shared connection approach Instead of managing thread-local connections that can't be reliably cleaned up in ThreadPoolExecutor, use a single shared connection for all database types. **Changes:** 1. Removed thread-local connection logic (eliminates FD leak root cause) 2. All database types now use shared connection + self._lock 3. SQLite's WAL mode provides sufficient concurrency even with single connection 4. Fixed all 12 instances of threading.Lock() bug (4 in SQLiteSession, 8 in AdvancedSQLiteSession) 5. Kept _is_memory_db attribute for backward compatibility with AdvancedSQLiteSession 6. Added close() and __del__() methods for proper cleanup **Results (60s stress test, 30 writers + 10 readers):** ``` Main branch: - FD growth: +789 (leak) - Throughput: 701 ops/s - Errors: 0 on high-limit systems, 646k+ on normal systems After fix: - FD growth: +44 (stable) - Throughput: 726 ops/s (+3.6% improvement) - Errors: 0 on all systems - All 29 SQLite tests pass ``` ## Why shared connection performs better SQLite's WAL (Write-Ahead Logging) mode already provides: - Multiple concurrent readers - One writer coexisting with readers - Readers don't block writer - Writer doesn't block readers (except during checkpoint) The overhead of managing multiple connections outweighs any concurrency benefit. ## Backward compatibility The _is_memory_db attribute is preserved for AdvancedSQLiteSession compatibility, even though the implementation no longer differentiates connection strategies. ## Testing Comprehensive stress test available at: https://gist.github.com/gn00295120/0b6a65fe6c0ac6b7a1ce23654eed3ffe Run with: `python sqlite_stress_test_final.py`
1 parent 4bc33e3 commit a4d77fe

File tree

2 files changed

+127
-118
lines changed

2 files changed

+127
-118
lines changed

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 82 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import asyncio
44
import json
55
import logging
6-
import threading
76
from contextlib import closing
87
from pathlib import Path
98
from typing import Any, Union, cast
@@ -146,7 +145,7 @@ def _get_all_items_sync():
146145
"""Synchronous helper to get all items for a branch."""
147146
conn = self._get_connection()
148147
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
149-
with self._lock if self._is_memory_db else threading.Lock():
148+
with self._lock:
150149
with closing(conn.cursor()) as cursor:
151150
if limit is None:
152151
cursor.execute(
@@ -191,7 +190,7 @@ def _get_items_sync():
191190
"""Synchronous helper to get items for a specific branch."""
192191
conn = self._get_connection()
193192
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
194-
with self._lock if self._is_memory_db else threading.Lock():
193+
with self._lock:
195194
with closing(conn.cursor()) as cursor:
196195
# Get message IDs in correct order for this branch
197196
if limit is None:
@@ -261,18 +260,19 @@ def _get_next_turn_number(self, branch_id: str) -> int:
261260
The next available turn number for the specified branch.
262261
"""
263262
conn = self._get_connection()
264-
with closing(conn.cursor()) as cursor:
265-
cursor.execute(
266-
"""
267-
SELECT COALESCE(MAX(user_turn_number), 0)
268-
FROM message_structure
269-
WHERE session_id = ? AND branch_id = ?
270-
""",
271-
(self.session_id, branch_id),
272-
)
273-
result = cursor.fetchone()
274-
max_turn = result[0] if result else 0
275-
return max_turn + 1
263+
with self._lock:
264+
with closing(conn.cursor()) as cursor:
265+
cursor.execute(
266+
"""
267+
SELECT COALESCE(MAX(user_turn_number), 0)
268+
FROM message_structure
269+
WHERE session_id = ? AND branch_id = ?
270+
""",
271+
(self.session_id, branch_id),
272+
)
273+
result = cursor.fetchone()
274+
max_turn = result[0] if result else 0
275+
return max_turn + 1
276276

277277
def _get_next_branch_turn_number(self, branch_id: str) -> int:
278278
"""Get the next branch turn number for a specific branch.
@@ -284,18 +284,19 @@ def _get_next_branch_turn_number(self, branch_id: str) -> int:
284284
The next available branch turn number for the specified branch.
285285
"""
286286
conn = self._get_connection()
287-
with closing(conn.cursor()) as cursor:
288-
cursor.execute(
289-
"""
290-
SELECT COALESCE(MAX(branch_turn_number), 0)
291-
FROM message_structure
292-
WHERE session_id = ? AND branch_id = ?
293-
""",
294-
(self.session_id, branch_id),
295-
)
296-
result = cursor.fetchone()
297-
max_turn = result[0] if result else 0
298-
return max_turn + 1
287+
with self._lock:
288+
with closing(conn.cursor()) as cursor:
289+
cursor.execute(
290+
"""
291+
SELECT COALESCE(MAX(branch_turn_number), 0)
292+
FROM message_structure
293+
WHERE session_id = ? AND branch_id = ?
294+
""",
295+
(self.session_id, branch_id),
296+
)
297+
result = cursor.fetchone()
298+
max_turn = result[0] if result else 0
299+
return max_turn + 1
299300

300301
def _get_current_turn_number(self) -> int:
301302
"""Get the current turn number for the current branch.
@@ -304,17 +305,18 @@ def _get_current_turn_number(self) -> int:
304305
The current turn number for the active branch.
305306
"""
306307
conn = self._get_connection()
307-
with closing(conn.cursor()) as cursor:
308-
cursor.execute(
309-
"""
310-
SELECT COALESCE(MAX(user_turn_number), 0)
311-
FROM message_structure
312-
WHERE session_id = ? AND branch_id = ?
313-
""",
314-
(self.session_id, self._current_branch_id),
315-
)
316-
result = cursor.fetchone()
317-
return result[0] if result else 0
308+
with self._lock:
309+
with closing(conn.cursor()) as cursor:
310+
cursor.execute(
311+
"""
312+
SELECT COALESCE(MAX(user_turn_number), 0)
313+
FROM message_structure
314+
WHERE session_id = ? AND branch_id = ?
315+
""",
316+
(self.session_id, self._current_branch_id),
317+
)
318+
result = cursor.fetchone()
319+
return result[0] if result else 0
318320

319321
async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None:
320322
"""Extract structure metadata with branch-aware turn tracking.
@@ -333,7 +335,7 @@ def _add_structure_sync():
333335
"""Synchronous helper to add structure metadata to database."""
334336
conn = self._get_connection()
335337
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
336-
with self._lock if self._is_memory_db else threading.Lock():
338+
with self._lock:
337339
# Get the IDs of messages we just inserted, in order
338340
with closing(conn.cursor()) as cursor:
339341
cursor.execute(
@@ -439,7 +441,7 @@ def _cleanup_sync():
439441
"""Synchronous helper to cleanup orphaned messages."""
440442
conn = self._get_connection()
441443
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
442-
with self._lock if self._is_memory_db else threading.Lock():
444+
with self._lock:
443445
with closing(conn.cursor()) as cursor:
444446
# Find messages without structure metadata
445447
cursor.execute(
@@ -586,7 +588,7 @@ def _validate_turn():
586588
except Exception:
587589
return "Unable to parse content"
588590

589-
turn_content = await asyncio.to_thread(_validate_turn)
591+
turn_content = await self._to_thread_with_lock(_validate_turn)
590592

591593
# Generate branch name if not provided
592594
if branch_name is None:
@@ -655,7 +657,7 @@ def _validate_branch():
655657
if count == 0:
656658
raise ValueError(f"Branch '{branch_id}' does not exist")
657659

658-
await asyncio.to_thread(_validate_branch)
660+
await self._to_thread_with_lock(_validate_branch)
659661

660662
old_branch = self._current_branch_id
661663
self._current_branch_id = branch_id
@@ -694,7 +696,7 @@ def _delete_sync():
694696
"""Synchronous helper to delete branch and associated data."""
695697
conn = self._get_connection()
696698
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
697-
with self._lock if self._is_memory_db else threading.Lock():
699+
with self._lock:
698700
with closing(conn.cursor()) as cursor:
699701
# First verify the branch exists
700702
cursor.execute(
@@ -756,36 +758,37 @@ async def list_branches(self) -> list[dict[str, Any]]:
756758
def _list_branches_sync():
757759
"""Synchronous helper to list all branches."""
758760
conn = self._get_connection()
759-
with closing(conn.cursor()) as cursor:
760-
cursor.execute(
761-
"""
762-
SELECT
763-
ms.branch_id,
764-
COUNT(*) as message_count,
765-
COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
766-
MIN(ms.created_at) as created_at
767-
FROM message_structure ms
768-
WHERE ms.session_id = ?
769-
GROUP BY ms.branch_id
770-
ORDER BY created_at
771-
""",
772-
(self.session_id,),
773-
)
774-
775-
branches = []
776-
for row in cursor.fetchall():
777-
branch_id, msg_count, user_turns, created_at = row
778-
branches.append(
779-
{
780-
"branch_id": branch_id,
781-
"message_count": msg_count,
782-
"user_turns": user_turns,
783-
"is_current": branch_id == self._current_branch_id,
784-
"created_at": created_at,
785-
}
761+
with self._lock:
762+
with closing(conn.cursor()) as cursor:
763+
cursor.execute(
764+
"""
765+
SELECT
766+
ms.branch_id,
767+
COUNT(*) as message_count,
768+
COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
769+
MIN(ms.created_at) as created_at
770+
FROM message_structure ms
771+
WHERE ms.session_id = ?
772+
GROUP BY ms.branch_id
773+
ORDER BY created_at
774+
""",
775+
(self.session_id,),
786776
)
787777

788-
return branches
778+
branches = []
779+
for row in cursor.fetchall():
780+
branch_id, msg_count, user_turns, created_at = row
781+
branches.append(
782+
{
783+
"branch_id": branch_id,
784+
"message_count": msg_count,
785+
"user_turns": user_turns,
786+
"is_current": branch_id == self._current_branch_id,
787+
"created_at": created_at,
788+
}
789+
)
790+
791+
return branches
789792

790793
return await asyncio.to_thread(_list_branches_sync)
791794

@@ -801,7 +804,7 @@ def _copy_sync():
801804
"""Synchronous helper to copy messages to new branch."""
802805
conn = self._get_connection()
803806
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
804-
with self._lock if self._is_memory_db else threading.Lock():
807+
with self._lock:
805808
with closing(conn.cursor()) as cursor:
806809
# Get all messages before the branch point
807810
cursor.execute(
@@ -928,7 +931,7 @@ def _get_turns_sync():
928931

929932
return turns
930933

931-
return await asyncio.to_thread(_get_turns_sync)
934+
return await self._to_thread_with_lock(_get_turns_sync)
932935

933936
async def find_turns_by_content(
934937
self, search_term: str, branch_id: str | None = None
@@ -984,7 +987,7 @@ def _search_sync():
984987

985988
return matches
986989

987-
return await asyncio.to_thread(_search_sync)
990+
return await self._to_thread_with_lock(_search_sync)
988991

989992
async def get_conversation_by_turns(
990993
self, branch_id: str | None = None
@@ -1022,7 +1025,7 @@ def _get_conversation_sync():
10221025
turns[turn_num].append({"type": msg_type, "tool_name": tool_name})
10231026
return turns
10241027

1025-
return await asyncio.to_thread(_get_conversation_sync)
1028+
return await self._to_thread_with_lock(_get_conversation_sync)
10261029

10271030
async def get_tool_usage(self, branch_id: str | None = None) -> list[tuple[str, int, int]]:
10281031
"""Get all tool usage by turn for specified branch.
@@ -1056,7 +1059,7 @@ def _get_tool_usage_sync():
10561059
)
10571060
return cursor.fetchall()
10581061

1059-
return await asyncio.to_thread(_get_tool_usage_sync)
1062+
return await self._to_thread_with_lock(_get_tool_usage_sync)
10601063

10611064
async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int] | None:
10621065
"""Get cumulative usage for session or specific branch.
@@ -1072,7 +1075,7 @@ def _get_usage_sync():
10721075
"""Synchronous helper to get session usage data."""
10731076
conn = self._get_connection()
10741077
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
1075-
with self._lock if self._is_memory_db else threading.Lock():
1078+
with self._lock:
10761079
if branch_id:
10771080
# Branch-specific usage
10781081
query = """
@@ -1220,7 +1223,7 @@ def _get_turn_usage_sync():
12201223
)
12211224
return results
12221225

1223-
result = await asyncio.to_thread(_get_turn_usage_sync)
1226+
result = await self._to_thread_with_lock(_get_turn_usage_sync)
12241227

12251228
return cast(Union[list[dict[str, Any]], dict[str, Any]], result)
12261229

@@ -1236,7 +1239,7 @@ def _update_sync():
12361239
"""Synchronous helper to update turn usage data."""
12371240
conn = self._get_connection()
12381241
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
1239-
with self._lock if self._is_memory_db else threading.Lock():
1242+
with self._lock:
12401243
# Serialize token details as JSON
12411244
input_details_json = None
12421245
output_details_json = None

0 commit comments

Comments
 (0)