Skip to content
Merged
821 changes: 413 additions & 408 deletions api/v1/api.gen.go

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions api/v1/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7871,6 +7871,9 @@ components:
type: string
working:
type: boolean
hasPendingPrompt:
type: boolean
description: "Whether the agent is waiting for user input"
model:
type: string
totalCost:
Expand All @@ -7891,6 +7894,9 @@ components:
$ref: "#/components/schemas/AgentSession"
working:
type: boolean
hasPendingPrompt:
type: boolean
description: "Whether the agent is waiting for user input"
model:
type: string
totalCost:
Expand Down
39 changes: 27 additions & 12 deletions internal/agent/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,11 @@ type APIConfig struct {

// SessionWithState is a session paired with its live runtime state,
// as returned by the session-listing API.
type SessionWithState struct {
	Session          Session `json:"session"`            // persisted session record
	Working          bool    `json:"working"`            // agent loop is currently processing
	HasPendingPrompt bool    `json:"has_pending_prompt"` // agent is waiting for user input
	Model            string  `json:"model,omitempty"`    // model ID in use, if known
	TotalCost        float64 `json:"total_cost"`         // accumulated LLM cost for the session
}

// NewAPI creates a new API instance.
Expand Down Expand Up @@ -310,10 +311,11 @@ func (a *API) collectActiveSessions(userID string, activeIDs map[string]struct{}
}
activeIDs[id] = struct{}{}
sessions = append(sessions, SessionWithState{
Session: sess,
Working: mgr.IsWorking(),
Model: mgr.GetModel(),
TotalCost: mgr.GetTotalCost(),
Session: sess,
Working: mgr.IsWorking(),
HasPendingPrompt: mgr.HasPendingPrompt(),
Model: mgr.GetModel(),
TotalCost: mgr.GetTotalCost(),
})
return true
})
Expand Down Expand Up @@ -638,10 +640,11 @@ func (a *API) GetSessionDetail(ctx context.Context, sessionID, userID string) (*
Messages: mgr.GetMessages(),
Session: &sess,
SessionState: &SessionState{
SessionID: sessionID,
Working: mgr.IsWorking(),
Model: mgr.GetModel(),
TotalCost: mgr.GetTotalCost(),
SessionID: sessionID,
Working: mgr.IsWorking(),
HasPendingPrompt: mgr.HasPendingPrompt(),
Model: mgr.GetModel(),
TotalCost: mgr.GetTotalCost(),
},
Delegates: mgr.GetDelegates(),
}, nil
Expand Down Expand Up @@ -763,6 +766,10 @@ const idleSessionTimeout = 30 * time.Minute
// cleanupInterval is how often the cleanup goroutine runs.
const cleanupInterval = 5 * time.Minute

// stuckHeartbeatTimeout is the maximum time without a heartbeat before
// a working session is considered stuck and cancelled (3x loopHeartbeatInterval,
// so three consecutive heartbeats must be missed before cancellation).
const stuckHeartbeatTimeout = 30 * time.Second

// StartCleanup begins periodic cleanup of idle sessions.
// It should be called once when the API is initialized and will
// stop when the context is cancelled.
Expand Down Expand Up @@ -801,6 +808,14 @@ func (a *API) cleanupIdleSessions() {
if sess.ParentSessionID != "" {
return true
}
// Detect stuck sessions: working but no heartbeat in 30s (3x the 10s interval).
if mgr.IsWorking() {
lastHB := mgr.LastHeartbeat()
if !lastHB.IsZero() && time.Since(lastHB) > stuckHeartbeatTimeout {
_ = mgr.Cancel(context.Background())
a.logger.Warn("Cancelled stuck session", "session_id", id)
}
}
if !mgr.IsWorking() && mgr.LastActivity().Before(cutoff) {
toDelete = append(toDelete, id)
}
Expand Down
71 changes: 71 additions & 0 deletions internal/agent/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1204,3 +1204,74 @@ func TestAPI_CleanupIdleSessions_DeletesIdleSession(t *testing.T) {
_, exists = api.sessions.Load("cleanup-cancel")
assert.False(t, exists, "idle session should be cleaned up")
}

func TestAPI_CleanupIdleSessions_CancelsStuckSession(t *testing.T) {
	t.Parallel()

	api := NewAPI(APIConfig{
		ConfigStore: newMockConfigStore(true),
		WorkingDir:  t.TempDir(),
	})

	// A working session whose heartbeat went stale a minute ago — even
	// though it shows recent activity — should be treated as stuck.
	sm := NewSessionManager(SessionManagerConfig{ID: "stuck-sess"})
	sm.mu.Lock()
	sm.working = true
	sm.lastActivity = time.Now()
	sm.lastHeartbeat = time.Now().Add(-time.Minute)
	sm.mu.Unlock()
	api.sessions.Store("stuck-sess", sm)

	api.cleanupIdleSessions()

	// Cancellation flips the working flag off.
	assert.False(t, sm.IsWorking(), "stuck session should be cancelled")
}

func TestAPI_CleanupIdleSessions_DoesNotCancelHealthyWorkingSession(t *testing.T) {
	t.Parallel()

	api := NewAPI(APIConfig{
		ConfigStore: newMockConfigStore(true),
		WorkingDir:  t.TempDir(),
	})

	// A working session with an up-to-date heartbeat must be left alone.
	sm := NewSessionManager(SessionManagerConfig{ID: "healthy-sess"})
	now := time.Now()
	sm.mu.Lock()
	sm.working = true
	sm.lastHeartbeat = now
	sm.lastActivity = now
	sm.mu.Unlock()
	api.sessions.Store("healthy-sess", sm)

	api.cleanupIdleSessions()

	_, exists := api.sessions.Load("healthy-sess")
	assert.True(t, exists, "healthy working session should not be removed")
}

func TestAPI_CleanupIdleSessions_DoesNotCancelZeroHeartbeat(t *testing.T) {
	t.Parallel()

	api := NewAPI(APIConfig{
		ConfigStore: newMockConfigStore(true),
		WorkingDir:  t.TempDir(),
	})

	// The loop has not started heartbeating yet, so lastHeartbeat is the
	// zero value; the stuck-session check must not fire in that state.
	sm := NewSessionManager(SessionManagerConfig{ID: "no-hb-sess"})
	sm.mu.Lock()
	sm.working = true
	sm.lastActivity = time.Now()
	sm.mu.Unlock()
	api.sessions.Store("no-hb-sess", sm)

	api.cleanupIdleSessions()

	_, stillThere := api.sessions.Load("no-hb-sess")
	assert.True(t, stillThere, "session with zero heartbeat should not be cancelled")
}
Comment on lines +1336 to +1405
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Use require for assertions in the new cleanup tests.

These subtests rely on assertions that should fail fast; switch to require.* for consistency and to match project conventions.
As per coding guidelines, "Use stretchr/testify/require for assertions and shared fixtures from internal/test instead of duplicating mocks".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@internal/agent/api_test.go` around lines 1208 - 1277, Update the three tests
TestAPI_CleanupIdleSessions_CancelsStuckSession,
TestAPI_CleanupIdleSessions_DoesNotCancelHealthyWorkingSession, and
TestAPI_CleanupIdleSessions_DoesNotCancelZeroHeartbeat to use require.* (from
stretchr/testify/require) instead of assert.* so failures fail fast; replace
assert.False/True and assert calls with the corresponding
require.False/True/NoError as appropriate, and stop duplicating mocks by using
the shared fixture from internal/test for the config store instead of
newMockConfigStore (i.e., replace newMockConfigStore(true) with the
project-provided mock from internal/test or its helper function), keeping
references to NewAPI, SessionManager, api.cleanupIdleSessions, and the session
IDs unchanged.

28 changes: 28 additions & 0 deletions internal/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ const (
// maxToolCallDepth limits nested tool call chains to prevent infinite recursion.
// This can happen if an LLM continuously makes tool calls without producing a final response.
maxToolCallDepth = 50

// loopHeartbeatInterval is the interval at which the loop emits heartbeats.
loopHeartbeatInterval = 10 * time.Second
)

// MessageRecordFunc is called to record new messages to persistent storage.
Expand Down Expand Up @@ -54,6 +57,8 @@ type LoopConfig struct {
SessionID string
// OnWorking is called when the working state changes.
OnWorking func(working bool)
// OnHeartbeat is called periodically to signal the loop is alive.
OnHeartbeat func()
// EmitUIAction is called when a tool wants to emit a UI action.
EmitUIAction UIActionFunc
// EmitUserPrompt is called when a tool wants to emit a user prompt.
Expand Down Expand Up @@ -91,6 +96,7 @@ type Loop struct {
workingDir string
sessionID string
onWorking func(working bool)
onHeartbeat func()
sequenceID int64
emitUIAction UIActionFunc
emitUserPrompt EmitUserPromptFunc
Expand Down Expand Up @@ -122,6 +128,7 @@ func NewLoop(config LoopConfig) *Loop {
workingDir: config.WorkingDir,
sessionID: config.SessionID,
onWorking: config.OnWorking,
onHeartbeat: config.OnHeartbeat,
emitUIAction: config.EmitUIAction,
emitUserPrompt: config.EmitUserPrompt,
waitUserResponse: config.WaitUserResponse,
Expand Down Expand Up @@ -163,6 +170,9 @@ func (l *Loop) Go(ctx context.Context) error {
idleTimer := time.NewTimer(idlePollingInterval)
defer idleTimer.Stop()

heartbeatTicker := time.NewTicker(loopHeartbeatInterval)
defer heartbeatTicker.Stop()

for {
select {
case <-ctx.Done():
Expand All @@ -171,6 +181,15 @@ func (l *Loop) Go(ctx context.Context) error {
default:
}

// Non-blocking heartbeat drain for progress between iterations.
select {
case <-heartbeatTicker.C:
if l.onHeartbeat != nil {
l.onHeartbeat()
}
default:
}

// Process any queued messages
l.mu.Lock()
hasQueuedMessages := len(l.messageQueue) > 0
Expand Down Expand Up @@ -201,6 +220,10 @@ func (l *Loop) Go(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-heartbeatTicker.C:
if l.onHeartbeat != nil {
l.onHeartbeat()
}
case <-idleTimer.C:
}
}
Expand Down Expand Up @@ -390,6 +413,11 @@ func (l *Loop) SetUserContext(u UserIdentity) {
// instead of recursion to prevent stack overflow with long tool call chains.
func (l *Loop) handleToolCalls(ctx context.Context, toolCalls []llm.ToolCall) error {
for depth := range maxToolCallDepth {
// Heartbeat so cleanup doesn't cancel long-running tool chains.
if l.onHeartbeat != nil {
l.onHeartbeat()
}

l.executeToolCalls(ctx, toolCalls)

resp, err := l.sendRequest(ctx)
Expand Down
47 changes: 47 additions & 0 deletions internal/agent/loop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,53 @@ func TestLoop_Go(t *testing.T) {
assert.Equal(t, "You are a helpful assistant.", capturedRequest.Messages[0].Content)
})

t.Run("calls OnHeartbeat during tool calls", func(t *testing.T) {
t.Parallel()

var mu sync.Mutex
heartbeatCount := 0

callCount := atomic.Int32{}
provider := &mockLLMProvider{
chatFunc: func(_ context.Context, _ *llm.ChatRequest) (*llm.ChatResponse, error) {
n := callCount.Add(1)
if n == 1 {
return &llm.ChatResponse{
FinishReason: "tool_calls",
ToolCalls: []llm.ToolCall{{
ID: "hb-call",
Type: "function",
Function: llm.ToolCallFunction{
Name: "think",
Arguments: `{"thought": "heartbeat test"}`,
},
}},
}, nil
}
return simpleStopResponse("done"), nil
},
}

loop := NewLoop(LoopConfig{
Provider: provider,
Tools: CreateTools(ToolConfig{}),
OnHeartbeat: func() {
mu.Lock()
heartbeatCount++
mu.Unlock()
},
})
loop.QueueUserMessage(llm.Message{Role: llm.RoleUser, Content: "test"})

runLoopForDuration(t, loop, 500*time.Millisecond)

mu.Lock()
count := heartbeatCount
mu.Unlock()

assert.GreaterOrEqual(t, count, 1, "OnHeartbeat should fire during tool call handling")
})

t.Run("accumulates token usage", func(t *testing.T) {
t.Parallel()

Expand Down
47 changes: 39 additions & 8 deletions internal/agent/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type SessionManager struct {
mu sync.Mutex
createdAt time.Time
lastActivity time.Time
lastHeartbeat time.Time
model string
messages []Message
subpub *SubPub[StreamResponse]
Expand Down Expand Up @@ -174,10 +175,11 @@ func (sm *SessionManager) SetWorking(working bool) {
sm.logger.Debug("agent working state changed", "working", working)
sm.subpub.Broadcast(StreamResponse{
SessionState: &SessionState{
SessionID: id,
Working: working,
Model: model,
TotalCost: totalCost,
SessionID: id,
Working: working,
HasPendingPrompt: sm.HasPendingPrompt(),
Model: model,
TotalCost: totalCost,
},
})
if callback != nil {
Expand Down Expand Up @@ -213,6 +215,29 @@ func (sm *SessionManager) LastActivity() time.Time {
return sm.lastActivity
}

// HasPendingPrompt reports whether the session has any user prompts
// awaiting a response.
func (sm *SessionManager) HasPendingPrompt() bool {
	sm.promptsMu.Lock()
	pending := len(sm.pendingPrompts)
	sm.promptsMu.Unlock()
	return pending > 0
}

// RecordHeartbeat stamps both the heartbeat and activity times with the
// current time, signalling that the session's loop is still alive.
func (sm *SessionManager) RecordHeartbeat() {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	ts := time.Now()
	sm.lastHeartbeat = ts
	sm.lastActivity = ts
}

// LastHeartbeat returns when the session last reported a heartbeat
// (the zero time if it never has).
func (sm *SessionManager) LastHeartbeat() time.Time {
	sm.mu.Lock()
	hb := sm.lastHeartbeat
	sm.mu.Unlock()
	return hb
}

// GetModel returns the model ID used by this session.
func (sm *SessionManager) GetModel() string {
sm.mu.Lock()
Expand Down Expand Up @@ -351,6 +376,10 @@ func (sm *SessionManager) SubscribeWithSnapshot(ctx context.Context) (StreamResp
model := sm.model
totalCost := sm.totalCost
id := sm.id

sm.promptsMu.Lock()
hasPendingPrompt := len(sm.pendingPrompts) > 0
sm.promptsMu.Unlock()
sess := Session{
ID: id,
UserID: sm.user.UserID,
Expand All @@ -374,10 +403,11 @@ func (sm *SessionManager) SubscribeWithSnapshot(ctx context.Context) (StreamResp
Messages: msgs,
Session: &sess,
SessionState: &SessionState{
SessionID: id,
Working: working,
Model: model,
TotalCost: totalCost,
SessionID: id,
Working: working,
HasPendingPrompt: hasPendingPrompt,
Model: model,
TotalCost: totalCost,
},
Delegates: delegates,
}, next
Expand Down Expand Up @@ -460,6 +490,7 @@ func (sm *SessionManager) createLoop(provider llm.Provider, model string, histor
WorkingDir: sm.workingDir,
SessionID: sm.id,
OnWorking: sm.SetWorking,
OnHeartbeat: sm.RecordHeartbeat,
EmitUIAction: sm.createEmitUIActionFunc(),
EmitUserPrompt: sm.createEmitUserPromptFunc(),
WaitUserResponse: sm.createWaitUserResponseFunc(),
Expand Down
Loading
Loading