Skip to content

Commit 28bec8d

Browse files
committed
fix: prevent sub-agent streaming messages from being persisted to parent session
When a task is transferred to a sub-agent, AgentChoiceEvent and AgentChoiceReasoningEvent from the sub-agent were being persisted to the parent session by PersistentRuntime.handleEvent. This corrupted the parent session's conversation history by interleaving sub-agent messages between the transfer_task tool call and its tool result, breaking the assistant/tool pairing required by LLM providers. Track sub-session depth via AgentSwitchingEvent and skip persistence of streaming content and MessageAddedEvent while inside a sub-session. Sub-session messages are correctly persisted via SubSessionCompletedEvent. Fixes #1738
1 parent 8df5bc5 commit 28bec8d

File tree

2 files changed

+162
-0
lines changed

2 files changed

+162
-0
lines changed

pkg/runtime/persistent_runtime.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type streamingState struct {
2323
reasoningContent strings.Builder
2424
agentName string
2525
messageID int64 // ID of the current streaming message (0 if none)
26+
subSessionDepth int // >0 when inside a sub-session (task transfer); skip parent persistence
2627
}
2728

2829
// New creates a new runtime for an agent and its team.
@@ -72,14 +73,30 @@ func (r *PersistentRuntime) handleEvent(ctx context.Context, sess *session.Sessi
7273
}
7374

7475
switch e := event.(type) {
76+
case *AgentSwitchingEvent:
77+
if e.Switching {
78+
streaming.subSessionDepth++
79+
} else if streaming.subSessionDepth > 0 {
80+
streaming.subSessionDepth--
81+
} else {
82+
slog.Warn("Received AgentSwitching(false) without matching AgentSwitching(true)",
83+
"session_id", sess.ID, "from_agent", e.FromAgent, "to_agent", e.ToAgent)
84+
}
85+
7586
case *AgentChoiceEvent:
87+
if streaming.subSessionDepth > 0 {
88+
return
89+
}
7690
// Accumulate streaming content
7791
streaming.content.WriteString(e.Content)
7892
streaming.agentName = e.AgentName
7993

8094
r.persistStreamingContent(ctx, sess.ID, streaming)
8195

8296
case *AgentChoiceReasoningEvent:
97+
if streaming.subSessionDepth > 0 {
98+
return
99+
}
83100
// Accumulate streaming reasoning content
84101
streaming.reasoningContent.WriteString(e.Content)
85102
streaming.agentName = e.AgentName
@@ -98,6 +115,9 @@ func (r *PersistentRuntime) handleEvent(ctx context.Context, sess *session.Sessi
98115
}
99116

100117
case *MessageAddedEvent:
118+
if streaming.subSessionDepth > 0 {
119+
return
120+
}
101121
// Finalize the streaming message with complete metadata
102122
if streaming.messageID != 0 {
103123
// Update the existing streaming message with final content
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/docker/cagent/pkg/agent"
12+
"github.com/docker/cagent/pkg/chat"
13+
"github.com/docker/cagent/pkg/model/provider/base"
14+
"github.com/docker/cagent/pkg/session"
15+
"github.com/docker/cagent/pkg/team"
16+
"github.com/docker/cagent/pkg/tools"
17+
"github.com/docker/cagent/pkg/tools/builtin"
18+
)
19+
20+
// multiStreamProvider returns different streams on consecutive calls.
21+
type multiStreamProvider struct {
22+
id string
23+
mu sync.Mutex
24+
streams []chat.MessageStream
25+
idx int
26+
}
27+
28+
func (m *multiStreamProvider) ID() string { return m.id }
29+
30+
func (m *multiStreamProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) {
31+
m.mu.Lock()
32+
defer m.mu.Unlock()
33+
if m.idx >= len(m.streams) {
34+
return m.streams[len(m.streams)-1], nil
35+
}
36+
s := m.streams[m.idx]
37+
m.idx++
38+
return s, nil
39+
}
40+
41+
func (m *multiStreamProvider) BaseConfig() base.Config { return base.Config{} }
42+
43+
func (m *multiStreamProvider) MaxTokens() int { return 0 }
44+
45+
func TestPersistentRuntime_SubAgentMessagesNotPersistedToParent(t *testing.T) {
46+
// Stream 1 (root): produces a transfer_task tool call to "worker"
47+
rootStream := newStreamBuilder().
48+
AddToolCallName("call_transfer", "transfer_task").
49+
AddToolCallArguments("call_transfer", `{"agent":"worker","task":"do work","expected_output":"result"}`).
50+
AddStopWithUsage(10, 5).
51+
Build()
52+
53+
// Stream 2 (worker sub-agent): produces streaming content simulating work
54+
workerStream := newStreamBuilder().
55+
AddContent("I am doing ").
56+
AddContent("the work now.").
57+
AddStopWithUsage(5, 10).
58+
Build()
59+
60+
prov := &multiStreamProvider{
61+
id: "test/mock-model",
62+
streams: []chat.MessageStream{rootStream, workerStream},
63+
}
64+
65+
worker := agent.New("worker", "Worker agent", agent.WithModel(prov))
66+
root := agent.New("root", "Root coordinator",
67+
agent.WithModel(prov),
68+
agent.WithToolSets(builtin.NewTransferTaskTool()),
69+
)
70+
agent.WithSubAgents(worker)(root)
71+
72+
tm := team.New(team.WithAgents(root, worker))
73+
74+
store := session.NewInMemorySessionStore()
75+
76+
rt, err := New(tm,
77+
WithSessionCompaction(false),
78+
WithModelStore(mockModelStore{}),
79+
WithSessionStore(store),
80+
)
81+
require.NoError(t, err)
82+
83+
sess := session.New(
84+
session.WithUserMessage("Please delegate work to the worker"),
85+
session.WithToolsApproved(true),
86+
)
87+
sess.Title = "Test Transfer Persistence"
88+
89+
err = store.AddSession(t.Context(), sess)
90+
require.NoError(t, err)
91+
92+
evCh := rt.RunStream(t.Context(), sess)
93+
for range evCh {
94+
}
95+
96+
parentSess, err := store.GetSession(t.Context(), sess.ID)
97+
require.NoError(t, err)
98+
99+
// Verify no sub-agent messages leaked into the parent session
100+
for _, item := range parentSess.Messages {
101+
if !item.IsMessage() {
102+
continue
103+
}
104+
assert.NotEqual(t, "worker", item.Message.AgentName,
105+
"Sub-agent 'worker' messages should not be in the parent session. "+
106+
"Found message with role=%s content=%q",
107+
item.Message.Message.Role, item.Message.Message.Content)
108+
}
109+
110+
// Verify the sub-session was persisted and contains the worker's messages
111+
var subSess *session.Session
112+
for _, item := range parentSess.Messages {
113+
if item.IsSubSession() {
114+
subSess = item.SubSession
115+
break
116+
}
117+
}
118+
require.NotNil(t, subSess,
119+
"Sub-session should be persisted in the parent session")
120+
121+
var workerMsgCount int
122+
for _, item := range subSess.Messages {
123+
if item.IsMessage() && item.Message.AgentName == "worker" {
124+
workerMsgCount++
125+
}
126+
}
127+
assert.Greater(t, workerMsgCount, 0,
128+
"Worker messages should be in the sub-session")
129+
130+
// Verify the root agent's assistant message (with transfer_task tool call)
131+
// and the tool result are both persisted in the parent
132+
var roles []chat.MessageRole
133+
for _, item := range parentSess.Messages {
134+
if item.IsMessage() {
135+
roles = append(roles, item.Message.Message.Role)
136+
}
137+
}
138+
assert.Contains(t, roles, chat.MessageRoleAssistant,
139+
"Parent session should contain root's assistant message with the transfer_task tool call")
140+
assert.Contains(t, roles, chat.MessageRoleTool,
141+
"Parent session should contain the tool result for transfer_task")
142+
}

0 commit comments

Comments
 (0)