Skip to content

Commit 5e6096a

Browse files
committed
fix(#1738): prevent sub-agent streaming messages from being persisted to parent session
When a task is transferred to a sub-agent, AgentChoiceEvent and AgentChoiceReasoningEvent from the sub-agent were being persisted to the parent session by PersistentRuntime.handleEvent. This corrupted the parent session's conversation history by interleaving sub-agent messages between the transfer_task tool call and its tool result, breaking the assistant→tool pairing required by LLM providers. Track sub-session depth via AgentSwitchingEvent and skip persistence of streaming content and MessageAddedEvent while inside a sub-session. Sub-session messages are correctly persisted via SubSessionCompletedEvent. Fixes #1738
1 parent 8df5bc5 commit 5e6096a

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed

pkg/runtime/persistent_runtime.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type streamingState struct {
2323
reasoningContent strings.Builder
2424
agentName string
2525
messageID int64 // ID of the current streaming message (0 if none)
26+
subSessionDepth int // >0 when inside a sub-session (task transfer); skip parent persistence
2627
}
2728

2829
// New creates a new runtime for an agent and its team.
@@ -72,14 +73,27 @@ func (r *PersistentRuntime) handleEvent(ctx context.Context, sess *session.Sessi
7273
}
7374

7475
switch e := event.(type) {
76+
case *AgentSwitchingEvent:
77+
if e.Switching {
78+
streaming.subSessionDepth++
79+
} else {
80+
streaming.subSessionDepth--
81+
}
82+
7583
case *AgentChoiceEvent:
84+
if streaming.subSessionDepth > 0 {
85+
return
86+
}
7687
// Accumulate streaming content
7788
streaming.content.WriteString(e.Content)
7889
streaming.agentName = e.AgentName
7990

8091
r.persistStreamingContent(ctx, sess.ID, streaming)
8192

8293
case *AgentChoiceReasoningEvent:
94+
if streaming.subSessionDepth > 0 {
95+
return
96+
}
8397
// Accumulate streaming reasoning content
8498
streaming.reasoningContent.WriteString(e.Content)
8599
streaming.agentName = e.AgentName
@@ -98,6 +112,9 @@ func (r *PersistentRuntime) handleEvent(ctx context.Context, sess *session.Sessi
98112
}
99113

100114
case *MessageAddedEvent:
115+
if streaming.subSessionDepth > 0 {
116+
return
117+
}
101118
// Finalize the streaming message with complete metadata
102119
if streaming.messageID != 0 {
103120
// Update the existing streaming message with final content
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"sync"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/docker/cagent/pkg/agent"
12+
"github.com/docker/cagent/pkg/chat"
13+
"github.com/docker/cagent/pkg/model/provider/base"
14+
"github.com/docker/cagent/pkg/session"
15+
"github.com/docker/cagent/pkg/team"
16+
"github.com/docker/cagent/pkg/tools"
17+
"github.com/docker/cagent/pkg/tools/builtin"
18+
)
19+
20+
// multiStreamProvider returns different streams on consecutive calls.
21+
type multiStreamProvider struct {
22+
id string
23+
mu sync.Mutex
24+
streams []chat.MessageStream
25+
idx int
26+
}
27+
28+
func (m *multiStreamProvider) ID() string { return m.id }
29+
30+
func (m *multiStreamProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) {
31+
m.mu.Lock()
32+
defer m.mu.Unlock()
33+
if m.idx >= len(m.streams) {
34+
return m.streams[len(m.streams)-1], nil
35+
}
36+
s := m.streams[m.idx]
37+
m.idx++
38+
return s, nil
39+
}
40+
41+
func (m *multiStreamProvider) BaseConfig() base.Config { return base.Config{} }
42+
43+
func (m *multiStreamProvider) MaxTokens() int { return 0 }
44+
45+
func TestPersistentRuntime_SubAgentMessagesNotPersistedToParent(t *testing.T) {
46+
// Stream 1 (root): produces a transfer_task tool call to "worker"
47+
rootStream := newStreamBuilder().
48+
AddToolCallName("call_transfer", "transfer_task").
49+
AddToolCallArguments("call_transfer", `{"agent":"worker","task":"do work","expected_output":"result"}`).
50+
AddStopWithUsage(10, 5).
51+
Build()
52+
53+
// Stream 2 (worker sub-agent): produces streaming content simulating work
54+
workerStream := newStreamBuilder().
55+
AddContent("I am doing ").
56+
AddContent("the work now.").
57+
AddStopWithUsage(5, 10).
58+
Build()
59+
60+
// Stream 3 (root after transfer completes): final response
61+
rootFinalStream := newStreamBuilder().
62+
AddContent("Task completed.").
63+
AddStopWithUsage(3, 2).
64+
Build()
65+
66+
prov := &multiStreamProvider{
67+
id: "test/mock-model",
68+
streams: []chat.MessageStream{rootStream, workerStream, rootFinalStream},
69+
}
70+
71+
worker := agent.New("worker", "Worker agent", agent.WithModel(prov))
72+
root := agent.New("root", "Root coordinator",
73+
agent.WithModel(prov),
74+
agent.WithToolSets(builtin.NewTransferTaskTool()),
75+
)
76+
agent.WithSubAgents(worker)(root)
77+
78+
tm := team.New(team.WithAgents(root, worker))
79+
80+
store := session.NewInMemorySessionStore()
81+
82+
rt, err := New(tm,
83+
WithSessionCompaction(false),
84+
WithModelStore(mockModelStore{}),
85+
WithSessionStore(store),
86+
)
87+
require.NoError(t, err)
88+
89+
sess := session.New(
90+
session.WithUserMessage("Please delegate work to the worker"),
91+
session.WithToolsApproved(true),
92+
)
93+
sess.Title = "Test Transfer Persistence"
94+
95+
err = store.AddSession(t.Context(), sess)
96+
require.NoError(t, err)
97+
98+
evCh := rt.RunStream(t.Context(), sess)
99+
100+
var events []Event
101+
for ev := range evCh {
102+
events = append(events, ev)
103+
}
104+
105+
// Retrieve the parent session from the store
106+
parentSess, err := store.GetSession(t.Context(), sess.ID)
107+
require.NoError(t, err)
108+
109+
// Check that no sub-agent messages are in the parent session
110+
for _, item := range parentSess.Messages {
111+
if !item.IsMessage() {
112+
continue
113+
}
114+
msg := item.Message
115+
assert.NotEqual(t, "worker", msg.AgentName,
116+
"Sub-agent 'worker' messages should not be persisted in the parent session. "+
117+
"Found message with role=%s content=%q", msg.Message.Role, msg.Message.Content)
118+
}
119+
}

0 commit comments

Comments
 (0)