Skip to content

Commit a6ce7ce

Browse files
committed
remove more soft deletion traces
1 parent 50e98bc commit a6ce7ce

File tree

2 files changed

+19
-23
lines changed

2 files changed

+19
-23
lines changed

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,15 @@
1616

1717

1818
class AdvancedSQLiteSession(SQLiteSession):
19-
"""Enhanced SQLite session with turn tracking, soft deletion, and usage analytics.
19+
"""Enhanced SQLite session with conversation branching and usage analytics.
2020
2121
Features:
22-
- Turn-based conversation management with soft delete/reactivate
22+
- Conversation branching from any user message
23+
- Independent branch management with turn tracking
2324
- Detailed usage tracking per turn with token breakdowns
2425
- Message structure metadata and tool usage statistics
2526
"""
2627

27-
ACTIVE = 1 # Message is active and visible in conversation
28-
INACTIVE = 0 # Message is soft-deleted (hidden but preserved)
29-
3028
def __init__(
3129
self,
3230
*,
@@ -96,7 +94,6 @@ def _init_structure_tables(self):
9694
CREATE INDEX IF NOT EXISTS idx_structure_turn
9795
ON message_structure(session_id, branch_id, user_turn_number)
9896
""")
99-
# Compound index for optimal performance on get_items queries
10097
conn.execute("""
10198
CREATE INDEX IF NOT EXISTS idx_structure_branch_seq
10299
ON message_structure(session_id, branch_id, sequence_number)
@@ -375,15 +372,13 @@ def _is_user_message(self, item: TResponseInputItem) -> bool:
375372
async def get_items(
376373
self,
377374
limit: int | None = None,
378-
include_inactive: bool = False,
379375
branch_id: str | None = None,
380376
) -> list[TResponseInputItem]:
381-
"""Get items from current or specified branch, optionally including soft-deleted ones."""
377+
"""Get items from current or specified branch."""
382378
if branch_id is None:
383379
branch_id = self._current_branch_id
384380

385-
if include_inactive:
386-
# Get all items (active and inactive) for this branch
381+
# Get all items for this branch
387382
def _get_all_items_sync():
388383
conn = self._get_connection()
389384
with self._lock if self._is_memory_db else threading.Lock():
@@ -427,12 +422,11 @@ def _get_all_items_sync():
427422

428423
return await asyncio.to_thread(_get_all_items_sync)
429424

430-
# Filter to only active items in this branch
431-
def _get_active_items_sync():
425+
def _get_items_sync():
432426
conn = self._get_connection()
433427
with self._lock if self._is_memory_db else threading.Lock():
434428
with closing(conn.cursor()) as cursor:
435-
# Get active message IDs in correct order for this branch
429+
# Get message IDs in correct order for this branch
436430
if limit is None:
437431
cursor.execute(
438432
"""
@@ -470,7 +464,7 @@ def _get_active_items_sync():
470464
continue
471465
return items
472466

473-
return await asyncio.to_thread(_get_active_items_sync)
467+
return await asyncio.to_thread(_get_items_sync)
474468

475469
async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_number: int) -> None:
476470
"""Copy messages before the branch point to the new branch."""
@@ -562,7 +556,7 @@ async def create_branch_from_turn(
562556
The branch_id of the newly created branch
563557
564558
Raises:
565-
ValueError: If turn doesn't exist, isn't active, or doesn't contain a user message
559+
ValueError: If turn doesn't exist or doesn't contain a user message
566560
"""
567561
import time
568562

@@ -576,14 +570,15 @@ def _validate_turn():
576570
FROM message_structure ms
577571
JOIN agent_messages am ON ms.message_id = am.id
578572
WHERE ms.session_id = ? AND ms.branch_id = ?
579-
AND ms.branch_turn_number = ? AND ms.message_type = 'user' """,
573+
AND ms.branch_turn_number = ? AND ms.message_type = 'user'
574+
""",
580575
(self.session_id, self._current_branch_id, turn_number),
581576
)
582577

583578
result = cursor.fetchone()
584579
if not result:
585580
raise ValueError(
586-
f"Turn {turn_number} does not contain an active user message "
581+
f"Turn {turn_number} does not contain a user message "
587582
f"in branch '{self._current_branch_id}'"
588583
)
589584

@@ -862,9 +857,9 @@ async def list_branches(self) -> list[dict[str, Any]]:
862857
Returns:
863858
List of dicts with branch info: {
864859
'branch_id': str, # Branch identifier
865-
'message_count': int, # Number of active messages in branch
860+
'message_count': int, # Number of messages in branch
866861
'user_turns': int, # Number of user turns in branch
867-
'is_current': bool, # Whether this is the current active branch
862+
'is_current': bool, # Whether this is the current branch
868863
'created_at': str # When the branch was first created
869864
}
870865
"""
@@ -880,7 +875,8 @@ def _list_branches_sync():
880875
COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
881876
MIN(ms.created_at) as created_at
882877
FROM message_structure ms
883-
WHERE ms.session_id = ? GROUP BY ms.branch_id
878+
WHERE ms.session_id = ?
879+
GROUP BY ms.branch_id
884880
ORDER BY created_at
885881
""",
886882
(self.session_id,),

tests/extensions/memory/test_advanced_sqlite_session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ async def test_branch_error_handling():
410410
session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)
411411

412412
# Test creating branch from non-existent turn
413-
with pytest.raises(ValueError, match="Turn 5 does not contain an active user message"):
413+
with pytest.raises(ValueError, match="Turn 5 does not contain a user message"):
414414
await session.create_branch_from_turn(5, "error_branch")
415415

416416
# Test switching to non-existent branch
@@ -499,8 +499,8 @@ async def test_get_items_with_parameters():
499499
main_items = await session.get_items(branch_id="main")
500500
assert len(main_items) == 4
501501

502-
# Test get_items with include_inactive (should be same as without it for now)
503-
all_items = await session.get_items(include_inactive=True)
502+
# Test get_items (no longer has include_inactive parameter)
503+
all_items = await session.get_items()
504504
assert len(all_items) == 4
505505

506506
# Create a branch from turn 2 and test branch-specific get_items

0 commit comments

Comments
 (0)