Skip to content

Commit dfddca5

Browse files
committed
Remove BranchSession from Store interface
Export buildBranchedSession as session.BranchSession so callers compose it with GetSession and AddSession directly, rather than having the store own branching logic. Also removes the now-unused collectSessionIDs helper. Assisted-By: cagent Signed-off-by: Djordje Lukic <djordje.lukic@docker.com>
1 parent f78dccb commit dfddca5

File tree

5 files changed

+32
-73
lines changed

5 files changed

+32
-73
lines changed

pkg/session/branch.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import (
99
"github.com/docker/cagent/pkg/tools"
1010
)
1111

12-
func buildBranchedSession(parent *Session, branchAtPosition int) (*Session, error) {
12+
// BranchSession creates a new session branched from the parent at the given position.
13+
// Messages up to (but not including) branchAtPosition are deep-cloned into the new session.
14+
func BranchSession(parent *Session, branchAtPosition int) (*Session, error) {
1315
if parent == nil {
1416
return nil, fmt.Errorf("parent session is nil")
1517
}
@@ -242,15 +244,3 @@ func recalculateSessionTotals(sess *Session) {
242244
sess.OutputTokens = outputTokens
243245
sess.Cost = cost
244246
}
245-
246-
func collectSessionIDs(sess *Session, ids map[string]struct{}) {
247-
if sess == nil || ids == nil {
248-
return
249-
}
250-
ids[sess.ID] = struct{}{}
251-
for _, item := range sess.Messages {
252-
if item.SubSession != nil {
253-
collectSessionIDs(item.SubSession, ids)
254-
}
255-
}
256-
}

pkg/session/branch_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,30 +90,30 @@ func TestCloneSessionItem(t *testing.T) {
9090
})
9191
}
9292

93-
func TestBuildBranchedSession(t *testing.T) {
93+
func TestBranchSession(t *testing.T) {
9494
t.Run("nil parent returns error", func(t *testing.T) {
95-
_, err := buildBranchedSession(nil, 0)
95+
_, err := BranchSession(nil, 0)
9696
require.Error(t, err)
9797
assert.Contains(t, err.Error(), "parent session is nil")
9898
})
9999

100100
t.Run("negative position returns error", func(t *testing.T) {
101101
parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}}
102-
_, err := buildBranchedSession(parent, -1)
102+
_, err := BranchSession(parent, -1)
103103
require.Error(t, err)
104104
assert.Contains(t, err.Error(), "out of range")
105105
})
106106

107107
t.Run("position beyond messages returns error", func(t *testing.T) {
108108
parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}}
109-
_, err := buildBranchedSession(parent, 2)
109+
_, err := BranchSession(parent, 2)
110110
require.Error(t, err)
111111
assert.Contains(t, err.Error(), "out of range")
112112
})
113113

114114
t.Run("position equal to messages length returns error", func(t *testing.T) {
115115
parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}}
116-
_, err := buildBranchedSession(parent, 1)
116+
_, err := BranchSession(parent, 1)
117117
require.Error(t, err)
118118
assert.Contains(t, err.Error(), "out of range")
119119
})
@@ -129,7 +129,7 @@ func TestBuildBranchedSession(t *testing.T) {
129129
},
130130
}
131131

132-
branched, err := buildBranchedSession(parent, 2)
132+
branched, err := BranchSession(parent, 2)
133133
require.NoError(t, err)
134134
assert.NotNil(t, branched)
135135

pkg/session/store.go

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ type Store interface {
8383
DeleteSession(ctx context.Context, id string) error
8484
UpdateSession(ctx context.Context, session *Session) error // Updates metadata only (not messages/items)
8585
SetSessionStarred(ctx context.Context, id string, starred bool) error
86-
BranchSession(ctx context.Context, parentSessionID string, branchAtPosition int) (*Session, error)
8786

8887
// === Granular item operations ===
8988

@@ -218,25 +217,6 @@ func (s *InMemorySessionStore) SetSessionStarred(_ context.Context, id string, s
218217
return nil
219218
}
220219

221-
// BranchSession creates a new session branched from the parent at the given position.
222-
func (s *InMemorySessionStore) BranchSession(_ context.Context, parentSessionID string, branchAtPosition int) (*Session, error) {
223-
if parentSessionID == "" {
224-
return nil, ErrEmptyID
225-
}
226-
parent, exists := s.sessions.Load(parentSessionID)
227-
if !exists {
228-
return nil, ErrNotFound
229-
}
230-
231-
branched, err := buildBranchedSession(parent, branchAtPosition)
232-
if err != nil {
233-
return nil, err
234-
}
235-
236-
s.sessions.Store(branched.ID, branched)
237-
return branched, nil
238-
}
239-
240220
// AddMessage adds a message to a session at the next position.
241221
// Returns the ID of the created message (for in-memory, this is a simple counter).
242222
func (s *InMemorySessionStore) AddMessage(_ context.Context, sessionID string, msg *Message) (int64, error) {
@@ -1075,37 +1055,6 @@ func (s *SQLiteSessionStore) SetSessionStarred(ctx context.Context, id string, s
10751055
return nil
10761056
}
10771057

1078-
// BranchSession creates a new session branched from the parent at the given position.
1079-
func (s *SQLiteSessionStore) BranchSession(ctx context.Context, parentSessionID string, branchAtPosition int) (*Session, error) {
1080-
if parentSessionID == "" {
1081-
return nil, ErrEmptyID
1082-
}
1083-
1084-
parent, err := s.GetSession(ctx, parentSessionID)
1085-
if err != nil {
1086-
return nil, err
1087-
}
1088-
1089-
branched, err := buildBranchedSession(parent, branchAtPosition)
1090-
if err != nil {
1091-
return nil, err
1092-
}
1093-
1094-
if err := s.AddSession(ctx, branched); err != nil {
1095-
return nil, err
1096-
}
1097-
1098-
ids := make(map[string]struct{})
1099-
collectSessionIDs(branched, ids)
1100-
for id := range ids {
1101-
if err := s.syncMessagesColumn(ctx, id); err != nil {
1102-
slog.Warn("[STORE] Failed to sync messages column after branch", "session_id", id, "error", err)
1103-
}
1104-
}
1105-
1106-
return branched, nil
1107-
}
1108-
11091058
// Close closes the database connection
11101059
func (s *SQLiteSessionStore) Close() error {
11111060
return s.db.Close()

pkg/session/store_test.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,13 @@ func TestBranchSessionCopiesPrefix(t *testing.T) {
244244

245245
require.NoError(t, store.AddSession(t.Context(), parent))
246246

247-
branched, err := store.BranchSession(t.Context(), parent.ID, 2)
247+
parentLoaded, err := store.GetSession(t.Context(), parent.ID)
248248
require.NoError(t, err)
249+
250+
branched, err := BranchSession(parentLoaded, 2)
251+
require.NoError(t, err)
252+
253+
require.NoError(t, store.AddSession(t.Context(), branched))
249254
require.NotNil(t, branched.BranchParentPosition)
250255
assert.Equal(t, parent.ID, branched.BranchParentSessionID)
251256
assert.Equal(t, 2, *branched.BranchParentPosition)
@@ -289,9 +294,14 @@ func TestBranchSessionClonesSubSession(t *testing.T) {
289294

290295
require.NoError(t, store.AddSession(t.Context(), parent))
291296

292-
branched, err := store.BranchSession(t.Context(), parent.ID, 2)
297+
parentLoaded, err := store.GetSession(t.Context(), parent.ID)
293298
require.NoError(t, err)
294299

300+
branched, err := BranchSession(parentLoaded, 2)
301+
require.NoError(t, err)
302+
303+
require.NoError(t, store.AddSession(t.Context(), branched))
304+
295305
loaded, err := store.GetSession(t.Context(), branched.ID)
296306
require.NoError(t, err)
297307
require.Len(t, loaded.Messages, 2)

pkg/tui/handlers.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/docker/cagent/pkg/browser"
1616
"github.com/docker/cagent/pkg/evaluation"
1717
"github.com/docker/cagent/pkg/modelsdev"
18+
"github.com/docker/cagent/pkg/session"
1819
"github.com/docker/cagent/pkg/tools"
1920
mcptools "github.com/docker/cagent/pkg/tools/mcp"
2021
"github.com/docker/cagent/pkg/tui/components/notification"
@@ -106,11 +107,20 @@ func (a *appModel) handleBranchFromEdit(msg messages.BranchFromEditMsg) (tea.Mod
106107
return a, notification.ErrorCmd("No parent session for branch")
107108
}
108109

109-
newSess, err := store.BranchSession(context.Background(), msg.ParentSessionID, msg.BranchAtPosition)
110+
parent, err := store.GetSession(context.Background(), msg.ParentSessionID)
111+
if err != nil {
112+
return a, notification.ErrorCmd(fmt.Sprintf("Failed to load parent session: %v", err))
113+
}
114+
115+
newSess, err := session.BranchSession(parent, msg.BranchAtPosition)
110116
if err != nil {
111117
return a, notification.ErrorCmd(fmt.Sprintf("Failed to branch session: %v", err))
112118
}
113119

120+
if err := store.AddSession(context.Background(), newSess); err != nil {
121+
return a, notification.ErrorCmd(fmt.Sprintf("Failed to save branched session: %v", err))
122+
}
123+
114124
if current := a.application.Session(); current != nil {
115125
newSess.HideToolResults = current.HideToolResults
116126
newSess.ToolsApproved = current.ToolsApproved

0 commit comments

Comments
 (0)