Skip to content

Commit 001e339

Browse files
authored
Merge pull request #1668 from rumpl/session-store-cleanup
Session store cleanup
2 parents d52eaf6 + dfddca5 commit 001e339

File tree

5 files changed

+50
-234
lines changed

5 files changed

+50
-234
lines changed

pkg/session/branch.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import (
99
"github.com/docker/cagent/pkg/tools"
1010
)
1111

12-
func buildBranchedSession(parent *Session, branchAtPosition int) (*Session, error) {
12+
// BranchSession creates a new session branched from the parent at the given position.
13+
// Messages up to (but not including) branchAtPosition are deep-cloned into the new session.
14+
func BranchSession(parent *Session, branchAtPosition int) (*Session, error) {
1315
if parent == nil {
1416
return nil, fmt.Errorf("parent session is nil")
1517
}
@@ -242,15 +244,3 @@ func recalculateSessionTotals(sess *Session) {
242244
sess.OutputTokens = outputTokens
243245
sess.Cost = cost
244246
}
245-
246-
func collectSessionIDs(sess *Session, ids map[string]struct{}) {
247-
if sess == nil || ids == nil {
248-
return
249-
}
250-
ids[sess.ID] = struct{}{}
251-
for _, item := range sess.Messages {
252-
if item.SubSession != nil {
253-
collectSessionIDs(item.SubSession, ids)
254-
}
255-
}
256-
}

pkg/session/branch_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,30 +90,30 @@ func TestCloneSessionItem(t *testing.T) {
9090
})
9191
}
9292

93-
func TestBuildBranchedSession(t *testing.T) {
93+
func TestBranchSession(t *testing.T) {
9494
t.Run("nil parent returns error", func(t *testing.T) {
95-
_, err := buildBranchedSession(nil, 0)
95+
_, err := BranchSession(nil, 0)
9696
require.Error(t, err)
9797
assert.Contains(t, err.Error(), "parent session is nil")
9898
})
9999

100100
t.Run("negative position returns error", func(t *testing.T) {
101101
parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}}
102-
_, err := buildBranchedSession(parent, -1)
102+
_, err := BranchSession(parent, -1)
103103
require.Error(t, err)
104104
assert.Contains(t, err.Error(), "out of range")
105105
})
106106

107107
t.Run("position beyond messages returns error", func(t *testing.T) {
108108
parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}}
109-
_, err := buildBranchedSession(parent, 2)
109+
_, err := BranchSession(parent, 2)
110110
require.Error(t, err)
111111
assert.Contains(t, err.Error(), "out of range")
112112
})
113113

114114
t.Run("position equal to messages length returns error", func(t *testing.T) {
115115
parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}}
116-
_, err := buildBranchedSession(parent, 1)
116+
_, err := BranchSession(parent, 1)
117117
require.Error(t, err)
118118
assert.Contains(t, err.Error(), "out of range")
119119
})
@@ -129,7 +129,7 @@ func TestBuildBranchedSession(t *testing.T) {
129129
},
130130
}
131131

132-
branched, err := buildBranchedSession(parent, 2)
132+
branched, err := BranchSession(parent, 2)
133133
require.NoError(t, err)
134134
assert.NotNil(t, branched)
135135

pkg/session/store.go

Lines changed: 18 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,18 @@ func ResolveSessionID(ctx context.Context, store Store, ref string) (string, err
4949
if !isRelative {
5050
return ref, nil
5151
}
52-
return store.GetSessionByOffset(ctx, offset)
52+
53+
summaries, err := store.GetSessionSummaries(ctx)
54+
if err != nil {
55+
return "", fmt.Errorf("getting session summaries: %w", err)
56+
}
57+
58+
index := offset - 1
59+
if index >= len(summaries) {
60+
return "", fmt.Errorf("session offset %d out of range (have %d sessions)", offset, len(summaries))
61+
}
62+
63+
return summaries[index].ID, nil
5364
}
5465

5566
// Summary contains lightweight session metadata for listing purposes.
@@ -72,12 +83,6 @@ type Store interface {
7283
DeleteSession(ctx context.Context, id string) error
7384
UpdateSession(ctx context.Context, session *Session) error // Updates metadata only (not messages/items)
7485
SetSessionStarred(ctx context.Context, id string, starred bool) error
75-
BranchSession(ctx context.Context, parentSessionID string, branchAtPosition int) (*Session, error)
76-
77-
// GetSessionByOffset returns the session ID at the given offset from the most recent.
78-
// Offset 1 returns the most recent session, 2 returns the second most recent, etc.
79-
// Only root sessions are considered (sub-sessions are excluded).
80-
GetSessionByOffset(ctx context.Context, offset int) (string, error)
8186

8287
// === Granular item operations ===
8388

@@ -147,6 +152,9 @@ func (s *InMemorySessionStore) GetSessions(_ context.Context) ([]*Session, error
147152
func (s *InMemorySessionStore) GetSessionSummaries(_ context.Context) ([]Summary, error) {
148153
summaries := make([]Summary, 0, s.sessions.Length())
149154
s.sessions.Range(func(_ string, value *Session) bool {
155+
if value.ParentID != "" {
156+
return true
157+
}
150158
summaries = append(summaries, Summary{
151159
ID: value.ID,
152160
Title: value.Title,
@@ -156,6 +164,9 @@ func (s *InMemorySessionStore) GetSessionSummaries(_ context.Context) ([]Summary
156164
})
157165
return true
158166
})
167+
sort.Slice(summaries, func(i, j int) bool {
168+
return summaries[i].CreatedAt.After(summaries[j].CreatedAt)
169+
})
159170
return summaries, nil
160171
}
161172

@@ -206,25 +217,6 @@ func (s *InMemorySessionStore) SetSessionStarred(_ context.Context, id string, s
206217
return nil
207218
}
208219

209-
// BranchSession creates a new session branched from the parent at the given position.
210-
func (s *InMemorySessionStore) BranchSession(_ context.Context, parentSessionID string, branchAtPosition int) (*Session, error) {
211-
if parentSessionID == "" {
212-
return nil, ErrEmptyID
213-
}
214-
parent, exists := s.sessions.Load(parentSessionID)
215-
if !exists {
216-
return nil, ErrNotFound
217-
}
218-
219-
branched, err := buildBranchedSession(parent, branchAtPosition)
220-
if err != nil {
221-
return nil, err
222-
}
223-
224-
s.sessions.Store(branched.ID, branched)
225-
return branched, nil
226-
}
227-
228220
// AddMessage adds a message to a session at the next position.
229221
// Returns the ID of the created message (for in-memory, this is a simple counter).
230222
func (s *InMemorySessionStore) AddMessage(_ context.Context, sessionID string, msg *Message) (int64, error) {
@@ -358,34 +350,6 @@ func (s *InMemorySessionStore) UpdateSessionTitle(_ context.Context, sessionID,
358350
return nil
359351
}
360352

361-
// GetSessionByOffset returns the session ID at the given offset from the most recent.
362-
func (s *InMemorySessionStore) GetSessionByOffset(_ context.Context, offset int) (string, error) {
363-
if offset < 1 {
364-
return "", fmt.Errorf("offset must be >= 1, got %d", offset)
365-
}
366-
367-
// Collect and sort sessions by creation time (newest first)
368-
var sessions []*Session
369-
s.sessions.Range(func(_ string, value *Session) bool {
370-
// Only include root sessions (not sub-sessions)
371-
if value.ParentID == "" {
372-
sessions = append(sessions, value)
373-
}
374-
return true
375-
})
376-
377-
sort.Slice(sessions, func(i, j int) bool {
378-
return sessions[i].CreatedAt.After(sessions[j].CreatedAt)
379-
})
380-
381-
index := offset - 1 // offset 1 means index 0 (most recent session)
382-
if index >= len(sessions) {
383-
return "", fmt.Errorf("session offset %d out of range (have %d sessions)", offset, len(sessions))
384-
}
385-
386-
return sessions[index].ID, nil
387-
}
388-
389353
// NewSQLiteSessionStore creates a new SQLite session store
390354
func NewSQLiteSessionStore(path string) (Store, error) {
391355
store, err := openAndMigrateSQLiteStore(path)
@@ -1091,37 +1055,6 @@ func (s *SQLiteSessionStore) SetSessionStarred(ctx context.Context, id string, s
10911055
return nil
10921056
}
10931057

1094-
// BranchSession creates a new session branched from the parent at the given position.
1095-
func (s *SQLiteSessionStore) BranchSession(ctx context.Context, parentSessionID string, branchAtPosition int) (*Session, error) {
1096-
if parentSessionID == "" {
1097-
return nil, ErrEmptyID
1098-
}
1099-
1100-
parent, err := s.GetSession(ctx, parentSessionID)
1101-
if err != nil {
1102-
return nil, err
1103-
}
1104-
1105-
branched, err := buildBranchedSession(parent, branchAtPosition)
1106-
if err != nil {
1107-
return nil, err
1108-
}
1109-
1110-
if err := s.AddSession(ctx, branched); err != nil {
1111-
return nil, err
1112-
}
1113-
1114-
ids := make(map[string]struct{})
1115-
collectSessionIDs(branched, ids)
1116-
for id := range ids {
1117-
if err := s.syncMessagesColumn(ctx, id); err != nil {
1118-
slog.Warn("[STORE] Failed to sync messages column after branch", "session_id", id, "error", err)
1119-
}
1120-
}
1121-
1122-
return branched, nil
1123-
}
1124-
11251058
// Close closes the database connection
11261059
func (s *SQLiteSessionStore) Close() error {
11271060
return s.db.Close()
@@ -1400,28 +1333,3 @@ func (s *SQLiteSessionStore) UpdateSessionTitle(ctx context.Context, sessionID,
14001333
title, sessionID)
14011334
return err
14021335
}
1403-
1404-
// GetSessionByOffset returns the session ID at the given offset from the most recent.
1405-
func (s *SQLiteSessionStore) GetSessionByOffset(ctx context.Context, offset int) (string, error) {
1406-
if offset < 1 {
1407-
return "", fmt.Errorf("offset must be >= 1, got %d", offset)
1408-
}
1409-
1410-
// Query sessions ordered by creation time (newest first), limited to offset
1411-
// Only include root sessions (not sub-sessions)
1412-
var sessionID string
1413-
err := s.db.QueryRowContext(ctx,
1414-
`SELECT id FROM sessions
1415-
WHERE parent_id IS NULL OR parent_id = ''
1416-
ORDER BY created_at DESC
1417-
LIMIT 1 OFFSET ?`,
1418-
offset-1).Scan(&sessionID)
1419-
if err != nil {
1420-
if errors.Is(err, sql.ErrNoRows) {
1421-
return "", fmt.Errorf("session offset %d out of range", offset)
1422-
}
1423-
return "", err
1424-
}
1425-
1426-
return sessionID, nil
1427-
}

pkg/session/store_test.go

Lines changed: 12 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,13 @@ func TestBranchSessionCopiesPrefix(t *testing.T) {
244244

245245
require.NoError(t, store.AddSession(t.Context(), parent))
246246

247-
branched, err := store.BranchSession(t.Context(), parent.ID, 2)
247+
parentLoaded, err := store.GetSession(t.Context(), parent.ID)
248248
require.NoError(t, err)
249+
250+
branched, err := BranchSession(parentLoaded, 2)
251+
require.NoError(t, err)
252+
253+
require.NoError(t, store.AddSession(t.Context(), branched))
249254
require.NotNil(t, branched.BranchParentPosition)
250255
assert.Equal(t, parent.ID, branched.BranchParentSessionID)
251256
assert.Equal(t, 2, *branched.BranchParentPosition)
@@ -289,9 +294,14 @@ func TestBranchSessionClonesSubSession(t *testing.T) {
289294

290295
require.NoError(t, store.AddSession(t.Context(), parent))
291296

292-
branched, err := store.BranchSession(t.Context(), parent.ID, 2)
297+
parentLoaded, err := store.GetSession(t.Context(), parent.ID)
293298
require.NoError(t, err)
294299

300+
branched, err := BranchSession(parentLoaded, 2)
301+
require.NoError(t, err)
302+
303+
require.NoError(t, store.AddSession(t.Context(), branched))
304+
295305
loaded, err := store.GetSession(t.Context(), branched.ID)
296306
require.NoError(t, err)
297307
require.Len(t, loaded.Messages, 2)
@@ -1274,105 +1284,3 @@ func TestResolveSessionID_InMemory(t *testing.T) {
12741284
assert.Equal(t, "some-uuid", id)
12751285
})
12761286
}
1277-
1278-
func TestGetSessionByOffset_SQLite(t *testing.T) {
1279-
tempDB := filepath.Join(t.TempDir(), "test_offset.db")
1280-
1281-
store, err := NewSQLiteSessionStore(tempDB)
1282-
require.NoError(t, err)
1283-
defer store.(*SQLiteSessionStore).Close()
1284-
1285-
// Create sessions with known timestamps
1286-
baseTime := time.Now()
1287-
sessions := []struct {
1288-
id string
1289-
createdAt time.Time
1290-
}{
1291-
{"oldest", baseTime.Add(-3 * time.Hour)},
1292-
{"middle", baseTime.Add(-2 * time.Hour)},
1293-
{"newest", baseTime.Add(-1 * time.Hour)},
1294-
}
1295-
1296-
for _, s := range sessions {
1297-
err := store.AddSession(t.Context(), &Session{
1298-
ID: s.id,
1299-
CreatedAt: s.createdAt,
1300-
})
1301-
require.NoError(t, err)
1302-
}
1303-
1304-
t.Run("offset 1 returns newest", func(t *testing.T) {
1305-
id, err := store.GetSessionByOffset(t.Context(), 1)
1306-
require.NoError(t, err)
1307-
assert.Equal(t, "newest", id)
1308-
})
1309-
1310-
t.Run("offset 2 returns middle", func(t *testing.T) {
1311-
id, err := store.GetSessionByOffset(t.Context(), 2)
1312-
require.NoError(t, err)
1313-
assert.Equal(t, "middle", id)
1314-
})
1315-
1316-
t.Run("offset 3 returns oldest", func(t *testing.T) {
1317-
id, err := store.GetSessionByOffset(t.Context(), 3)
1318-
require.NoError(t, err)
1319-
assert.Equal(t, "oldest", id)
1320-
})
1321-
1322-
t.Run("offset 0 returns error", func(t *testing.T) {
1323-
_, err := store.GetSessionByOffset(t.Context(), 0)
1324-
require.Error(t, err)
1325-
})
1326-
1327-
t.Run("out of range offset returns error", func(t *testing.T) {
1328-
_, err := store.GetSessionByOffset(t.Context(), 4)
1329-
require.Error(t, err)
1330-
assert.Contains(t, err.Error(), "out of range")
1331-
})
1332-
}
1333-
1334-
func TestGetSessionByOffset_InMemory(t *testing.T) {
1335-
store := NewInMemorySessionStore()
1336-
1337-
// Create sessions with known timestamps
1338-
baseTime := time.Now()
1339-
sessions := []struct {
1340-
id string
1341-
createdAt time.Time
1342-
}{
1343-
{"oldest", baseTime.Add(-3 * time.Hour)},
1344-
{"middle", baseTime.Add(-2 * time.Hour)},
1345-
{"newest", baseTime.Add(-1 * time.Hour)},
1346-
}
1347-
1348-
for _, s := range sessions {
1349-
err := store.AddSession(t.Context(), &Session{
1350-
ID: s.id,
1351-
CreatedAt: s.createdAt,
1352-
})
1353-
require.NoError(t, err)
1354-
}
1355-
1356-
t.Run("offset 1 returns newest", func(t *testing.T) {
1357-
id, err := store.GetSessionByOffset(t.Context(), 1)
1358-
require.NoError(t, err)
1359-
assert.Equal(t, "newest", id)
1360-
})
1361-
1362-
t.Run("offset 2 returns middle", func(t *testing.T) {
1363-
id, err := store.GetSessionByOffset(t.Context(), 2)
1364-
require.NoError(t, err)
1365-
assert.Equal(t, "middle", id)
1366-
})
1367-
1368-
t.Run("offset 0 returns error", func(t *testing.T) {
1369-
_, err := store.GetSessionByOffset(t.Context(), 0)
1370-
require.Error(t, err)
1371-
})
1372-
1373-
t.Run("out of range offset returns error", func(t *testing.T) {
1374-
_, err := store.GetSessionByOffset(t.Context(), 4)
1375-
require.Error(t, err)
1376-
assert.Contains(t, err.Error(), "out of range")
1377-
})
1378-
}

0 commit comments

Comments
 (0)