Skip to content

Commit 58a1e11

Browse files
committed
feat(checkpointer): add session management with expiration and stats
- Add SessionInfo and SessionStats structs for session metadata
- Add ListSessions() to query all sessions ordered by last_active
- Add GetSessionInfo() to get a single session by threadID
- Add GetStats() for aggregate session statistics
- Add TouchSession() to track session activity and checkpoint counts
- Add CleanupExpiredSessions() to remove old sessions (default 24h)
- Add background cleanup worker running every hour
- Add SetMaxAge() and SetMaxCheckpoints() configuration methods
- Fix SQLite datetime parsing for aggregate functions
1 parent 8cd0570 commit 58a1e11

File tree

2 files changed

+585
-99
lines changed

2 files changed

+585
-99
lines changed

pkg/workflow/checkpointer.go

Lines changed: 246 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,47 @@ package workflow
22

33
import (
44
"context"
5+
"database/sql"
56
"fmt"
67
"os"
78
"path/filepath"
9+
"sync"
10+
"time"
811

912
lgg "github.com/smallnest/langgraphgo/graph"
1013
"github.com/smallnest/langgraphgo/store/sqlite"
1114
)
1215

13-
// CheckpointerManager manages the SQLite checkpointer for session persistence.
16+
// Defaults applied to a freshly constructed CheckpointerManager and to its
// background cleanup worker.
const (
	// DefaultSessionMaxAge is the initial maxAge of a manager: sessions whose
	// last_active is older than this are removed by CleanupExpiredSessions.
	DefaultSessionMaxAge = 24 * time.Hour

	// DefaultCleanupInterval is the ticker period of the cleanup worker.
	DefaultCleanupInterval = time.Hour

	// DefaultMaxCheckpoints is the initial maxCheckpoints of a manager.
	DefaultMaxCheckpoints = 100
)
21+
22+
// SessionInfo mirrors one row of the session_metadata table.
type SessionInfo struct {
	// ThreadID is the thread_id primary key.
	ThreadID string
	// CreatedAt is the created_at column; it stays the zero time when the
	// stored value is NULL.
	CreatedAt time.Time
	// LastActive is the last_active column; it stays the zero time when the
	// stored value is NULL.
	LastActive time.Time
	// CheckpointCount is the checkpoint_count column.
	CheckpointCount int
}
28+
29+
// SessionStats carries the aggregate figures GetStats computes over the
// session_metadata table.
type SessionStats struct {
	// TotalSessions is the number of rows in session_metadata.
	TotalSessions int
	// TotalCheckpoints is the sum of checkpoint_count over all rows
	// (0 when the table is empty).
	TotalCheckpoints int
	// OldestSession is the smallest created_at value, or nil when there is
	// none or it could not be parsed.
	OldestSession *time.Time
	// NewestSession is the largest created_at value, or nil when there is
	// none or it could not be parsed.
	NewestSession *time.Time
}
35+
1436
type CheckpointerManager struct {
15-
store *sqlite.SqliteCheckpointStore
16-
dataDir string
37+
store *sqlite.SqliteCheckpointStore
38+
db *sql.DB
39+
dataDir string
40+
maxAge time.Duration
41+
maxCheckpoints int
42+
mu sync.RWMutex
43+
stopCleanup chan struct{}
1744
}
1845

19-
// NewCheckpointerManager creates a new checkpointer manager.
20-
// If dataDir is empty, it defaults to ~/.k8s-wizard/checkpoints
2146
func NewCheckpointerManager(dataDir string) (*CheckpointerManager, error) {
2247
if dataDir == "" {
2348
homeDir, err := os.UserHomeDir()
@@ -27,7 +52,6 @@ func NewCheckpointerManager(dataDir string) (*CheckpointerManager, error) {
2752
dataDir = filepath.Join(homeDir, ".k8s-wizard", "checkpoints")
2853
}
2954

30-
// Ensure the directory exists
3155
if err := os.MkdirAll(dataDir, 0755); err != nil {
3256
return nil, fmt.Errorf("failed to create checkpoints directory: %w", err)
3357
}
@@ -40,33 +64,233 @@ func NewCheckpointerManager(dataDir string) (*CheckpointerManager, error) {
4064
return nil, fmt.Errorf("failed to create SQLite checkpoint store: %w", err)
4165
}
4266

43-
return &CheckpointerManager{
44-
store: store,
45-
dataDir: dataDir,
46-
}, nil
67+
db, err := sql.Open("sqlite3", dbPath+"?_loc=auto")
68+
if err != nil {
69+
store.Close()
70+
return nil, fmt.Errorf("failed to open database: %w", err)
71+
}
72+
73+
m := &CheckpointerManager{
74+
store: store,
75+
db: db,
76+
dataDir: dataDir,
77+
maxAge: DefaultSessionMaxAge,
78+
maxCheckpoints: DefaultMaxCheckpoints,
79+
stopCleanup: make(chan struct{}),
80+
}
81+
82+
if err := m.ensureSessionTable(); err != nil {
83+
m.Close()
84+
return nil, fmt.Errorf("failed to ensure session table: %w", err)
85+
}
86+
87+
go m.startCleanupWorker()
88+
89+
return m, nil
90+
}
91+
92+
func (m *CheckpointerManager) ensureSessionTable() error {
93+
_, err := m.db.Exec(`
94+
CREATE TABLE IF NOT EXISTS session_metadata (
95+
thread_id TEXT PRIMARY KEY,
96+
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
97+
last_active DATETIME DEFAULT CURRENT_TIMESTAMP,
98+
checkpoint_count INTEGER DEFAULT 0
99+
);
100+
CREATE INDEX IF NOT EXISTS idx_session_last_active ON session_metadata(last_active);
101+
CREATE INDEX IF NOT EXISTS idx_session_created_at ON session_metadata(created_at);
102+
`)
103+
return err
47104
}
48105

49-
// GetStore returns the underlying checkpoint store.
50106
// GetStore exposes the manager's SQLite checkpoint store through the
// lgg.CheckpointStore interface.
func (m *CheckpointerManager) GetStore() lgg.CheckpointStore {
	var cs lgg.CheckpointStore = m.store
	return cs
}
53109

54-
// Close closes the checkpointer store.
55110
func (m *CheckpointerManager) Close() error {
111+
close(m.stopCleanup)
112+
if m.db != nil {
113+
m.db.Close()
114+
}
56115
return m.store.Close()
57116
}
58117

59-
// ClearSession clears all checkpoints for a given thread/session.
60118
func (m *CheckpointerManager) ClearSession(ctx context.Context, threadID string) error {
61-
return m.store.Clear(ctx, threadID)
119+
if err := m.store.Clear(ctx, threadID); err != nil {
120+
return err
121+
}
122+
_, err := m.db.ExecContext(ctx, "DELETE FROM session_metadata WHERE thread_id = ?", threadID)
123+
return err
124+
}
125+
126+
func (m *CheckpointerManager) ListSessions(ctx context.Context) ([]SessionInfo, error) {
127+
rows, err := m.db.QueryContext(ctx, `
128+
SELECT thread_id, created_at, last_active, checkpoint_count
129+
FROM session_metadata
130+
ORDER BY last_active DESC
131+
`)
132+
if err != nil {
133+
return nil, fmt.Errorf("failed to list sessions: %w", err)
134+
}
135+
defer rows.Close()
136+
137+
var sessions []SessionInfo
138+
for rows.Next() {
139+
var si SessionInfo
140+
var createdAt, lastActive sql.NullTime
141+
if err := rows.Scan(&si.ThreadID, &createdAt, &lastActive, &si.CheckpointCount); err != nil {
142+
return nil, fmt.Errorf("failed to scan session: %w", err)
143+
}
144+
if createdAt.Valid {
145+
si.CreatedAt = createdAt.Time
146+
}
147+
if lastActive.Valid {
148+
si.LastActive = lastActive.Time
149+
}
150+
sessions = append(sessions, si)
151+
}
152+
return sessions, rows.Err()
153+
}
154+
155+
func (m *CheckpointerManager) GetSessionInfo(ctx context.Context, threadID string) (*SessionInfo, error) {
156+
var si SessionInfo
157+
var createdAt, lastActive sql.NullTime
158+
err := m.db.QueryRowContext(ctx, `
159+
SELECT thread_id, created_at, last_active, checkpoint_count
160+
FROM session_metadata
161+
WHERE thread_id = ?
162+
`, threadID).Scan(&si.ThreadID, &createdAt, &lastActive, &si.CheckpointCount)
163+
if err == sql.ErrNoRows {
164+
return nil, nil
165+
}
166+
if err != nil {
167+
return nil, fmt.Errorf("failed to get session info: %w", err)
168+
}
169+
if createdAt.Valid {
170+
si.CreatedAt = createdAt.Time
171+
}
172+
if lastActive.Valid {
173+
si.LastActive = lastActive.Time
174+
}
175+
return &si, nil
62176
}
63177

64-
// ListSessions lists all unique thread IDs in the store.
65-
// Note: This is a best-effort implementation as the store doesn't have a direct API for this.
66-
func (m *CheckpointerManager) ListSessions(ctx context.Context) ([]string, error) {
67-
// This implementation depends on the store's internal structure.
68-
// For now, we return an empty list as there's no direct API to list all thread IDs.
69-
// In a production implementation, you might want to track thread IDs separately
70-
// or use a database query to get unique thread IDs.
71-
return []string{}, nil
178+
func (m *CheckpointerManager) GetStats(ctx context.Context) (*SessionStats, error) {
179+
var stats SessionStats
180+
var oldestStr, newestStr sql.NullString
181+
err := m.db.QueryRowContext(ctx, `
182+
SELECT
183+
COUNT(*) as total_sessions,
184+
COALESCE(SUM(checkpoint_count), 0) as total_checkpoints,
185+
MIN(created_at) as oldest_session,
186+
MAX(created_at) as newest_session
187+
FROM session_metadata
188+
`).Scan(&stats.TotalSessions, &stats.TotalCheckpoints, &oldestStr, &newestStr)
189+
if err != nil {
190+
return nil, fmt.Errorf("failed to get stats: %w", err)
191+
}
192+
if oldestStr.Valid && oldestStr.String != "" {
193+
t, parseErr := time.Parse("2006-01-02 15:04:05-07:00", oldestStr.String)
194+
if parseErr != nil {
195+
t, parseErr = time.Parse("2006-01-02 15:04:05", oldestStr.String)
196+
}
197+
if parseErr == nil {
198+
stats.OldestSession = &t
199+
}
200+
}
201+
if newestStr.Valid && newestStr.String != "" {
202+
t, parseErr := time.Parse("2006-01-02 15:04:05-07:00", newestStr.String)
203+
if parseErr != nil {
204+
t, parseErr = time.Parse("2006-01-02 15:04:05", newestStr.String)
205+
}
206+
if parseErr == nil {
207+
stats.NewestSession = &t
208+
}
209+
}
210+
return &stats, nil
211+
}
212+
213+
func (m *CheckpointerManager) TouchSession(ctx context.Context, threadID string) error {
214+
now := time.Now()
215+
result, err := m.db.ExecContext(ctx, `
216+
INSERT INTO session_metadata (thread_id, created_at, last_active, checkpoint_count)
217+
VALUES (?, ?, ?, 1)
218+
ON CONFLICT(thread_id) DO UPDATE SET
219+
last_active = ?,
220+
checkpoint_count = checkpoint_count + 1
221+
`, threadID, now, now, now)
222+
if err != nil {
223+
return fmt.Errorf("failed to touch session: %w", err)
224+
}
225+
226+
affected, _ := result.RowsAffected()
227+
if affected == 0 {
228+
return fmt.Errorf("failed to update session metadata")
229+
}
230+
return nil
231+
}
232+
233+
func (m *CheckpointerManager) SetMaxAge(maxAge time.Duration) {
234+
m.mu.Lock()
235+
defer m.mu.Unlock()
236+
m.maxAge = maxAge
237+
}
238+
239+
func (m *CheckpointerManager) SetMaxCheckpoints(max int) {
240+
m.mu.Lock()
241+
defer m.mu.Unlock()
242+
m.maxCheckpoints = max
243+
}
244+
245+
func (m *CheckpointerManager) CleanupExpiredSessions(ctx context.Context) (int, error) {
246+
m.mu.RLock()
247+
cutoff := time.Now().Add(-m.maxAge)
248+
m.mu.RUnlock()
249+
250+
result, err := m.db.ExecContext(ctx, `
251+
DELETE FROM session_metadata WHERE last_active < ?
252+
`, cutoff)
253+
if err != nil {
254+
return 0, fmt.Errorf("failed to cleanup expired sessions: %w", err)
255+
}
256+
257+
affected, _ := result.RowsAffected()
258+
259+
threads, _ := m.db.QueryContext(ctx, `
260+
SELECT DISTINCT thread_id FROM checkpoints WHERE thread_id IS NOT NULL
261+
`)
262+
if threads != nil {
263+
defer threads.Close()
264+
for threads.Next() {
265+
var tid string
266+
if err := threads.Scan(&tid); err != nil {
267+
continue
268+
}
269+
var count int
270+
m.db.QueryRowContext(ctx, `
271+
SELECT COUNT(*) FROM session_metadata WHERE thread_id = ?
272+
`, tid).Scan(&count)
273+
if count == 0 {
274+
m.store.Clear(ctx, tid)
275+
}
276+
}
277+
}
278+
279+
return int(affected), nil
280+
}
281+
282+
func (m *CheckpointerManager) startCleanupWorker() {
283+
ticker := time.NewTicker(DefaultCleanupInterval)
284+
defer ticker.Stop()
285+
286+
for {
287+
select {
288+
case <-ticker.C:
289+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
290+
m.CleanupExpiredSessions(ctx)
291+
cancel()
292+
case <-m.stopCleanup:
293+
return
294+
}
295+
}
72296
}

0 commit comments

Comments
 (0)