Skip to content

Commit acb07de

Browse files
authored
chore: improve ai agent (#418)
1 parent 2e686fb commit acb07de

File tree

19 files changed

+1098
-322
lines changed

19 files changed

+1098
-322
lines changed

main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ func setupAPIRouter(r *gin.RouterGroup, cm *cluster.ClusterManager) {
213213
// AI chat routes
214214
api.GET("/ai/status", ai.HandleAIStatus)
215215
api.POST("/ai/chat", ai.HandleChat)
216-
api.POST("/ai/execute", ai.HandleExecute)
216+
api.POST("/ai/execute/continue", ai.HandleExecuteContinue)
217217

218218
api.Use(middleware.RBACMiddleware())
219219
resources.RegisterRoutes(api)

pkg/ai/agent.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,9 @@ func buildContextualSystemPrompt(pageCtx *PageContext, runtimeCtx runtimePromptC
280280
}
281281

282282
if language == "zh" {
283-
prompt += "\n\nResponse language:\n- Always respond in Simplified Chinese unless the user explicitly asks for another language."
283+
prompt += "\n\nResponse language:\n- Prefer replying in the same language as the user's latest message.\n- If the user's latest message language is unclear, respond in Simplified Chinese unless the user explicitly asks for another language."
284284
} else {
285-
prompt += "\n\nResponse language:\n- Always respond in English unless the user explicitly asks for another language."
285+
prompt += "\n\nResponse language:\n- Prefer replying in the same language as the user's latest message.\n- If the user's latest message language is unclear, respond in English unless the user explicitly asks for another language."
286286
}
287287

288288
klog.V(4).Infof("system prompt %s", prompt)
@@ -299,6 +299,20 @@ func (a *Agent) ProcessChat(c *gin.Context, req *ChatRequest, sendEvent func(SSE
299299
}
300300
}
301301

302+
func (a *Agent) ContinuePendingAction(c *gin.Context, sessionID string, sendEvent func(SSEEvent)) error {
303+
session, err := agentPendingSessions.take(sessionID)
304+
if err != nil {
305+
return err
306+
}
307+
308+
switch session.Provider {
309+
case model.GeneralAIProviderAnthropic:
310+
return a.continueChatAnthropic(c, session, sendEvent)
311+
default:
312+
return a.continueChatOpenAI(c, session, sendEvent)
313+
}
314+
}
315+
302316
func parseToolCallArguments(raw string) (map[string]interface{}, error) {
303317
raw = strings.TrimSpace(raw)
304318
if raw == "" {

pkg/ai/anthropic.go

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ai
22

33
import (
4+
"context"
45
"fmt"
56
"strings"
67

@@ -34,7 +35,45 @@ func (a *Agent) processChatAnthropic(c *gin.Context, req *ChatRequest, sendEvent
3435
}
3536
sysPrompt := buildContextualSystemPrompt(req.PageContext, runtimeCtx, language)
3637
messages := toAnthropicMessages(req.Messages)
37-
tools := AnthropicToolDefs()
38+
a.runAnthropicConversation(ctx, c, sysPrompt, messages, sendEvent)
39+
}
40+
41+
func (a *Agent) continueChatAnthropic(c *gin.Context, session pendingSession, sendEvent func(SSEEvent)) error {
42+
ctx := c.Request.Context()
43+
result, isError := ExecuteTool(ctx, c, a.cs, session.ToolCall.Name, session.ToolCall.Args)
44+
45+
sendEvent(SSEEvent{
46+
Event: "tool_result",
47+
Data: map[string]interface{}{
48+
"tool": session.ToolCall.Name,
49+
"result": result,
50+
"is_error": isError,
51+
},
52+
})
53+
54+
toolResult := result
55+
if isError {
56+
toolResult = "Tool error: " + result
57+
}
58+
59+
session.AnthropicMessages = append(
60+
session.AnthropicMessages,
61+
anthropic.NewUserMessage(
62+
anthropic.NewToolResultBlock(session.ToolCall.ID, toolResult, isError),
63+
),
64+
)
65+
a.runAnthropicConversation(ctx, c, session.SystemPrompt, session.AnthropicMessages, sendEvent)
66+
return nil
67+
}
68+
69+
func (a *Agent) runAnthropicConversation(
70+
ctx context.Context,
71+
c *gin.Context,
72+
sysPrompt string,
73+
messages []anthropic.MessageParam,
74+
sendEvent func(SSEEvent),
75+
) {
76+
tools := AnthropicToolDefs(a.cs)
3877

3978
maxIterations := 100
4079
for i := 0; i < maxIterations; i++ {
@@ -100,11 +139,30 @@ func (a *Agent) processChatAnthropic(c *gin.Context, req *ChatRequest, sendEvent
100139
toolResults = append(toolResults, anthropic.NewToolResultBlock(tc.ID, "Tool error: "+result, true))
101140
continue
102141
}
142+
if len(toolResults) > 0 {
143+
messages = append(messages, anthropic.NewUserMessage(toolResults...))
144+
}
145+
sessionID := agentPendingSessions.save(pendingSession{
146+
Provider: a.provider,
147+
SystemPrompt: sysPrompt,
148+
AnthropicMessages: append([]anthropic.MessageParam(nil), messages...),
149+
ToolCall: pendingToolCall{
150+
ID: tc.ID,
151+
Name: toolName,
152+
Args: args,
153+
},
154+
})
155+
if sessionID == "" {
156+
errorMsg := "Failed to save pending session"
157+
toolResults = append(toolResults, anthropic.NewToolResultBlock(tc.ID, "Tool error: "+errorMsg, true))
158+
continue
159+
}
103160
sendEvent(SSEEvent{
104161
Event: "action_required",
105162
Data: map[string]interface{}{
106-
"tool": toolName,
107-
"args": args,
163+
"tool": toolName,
164+
"args": args,
165+
"session_id": sessionID,
108166
},
109167
})
110168
return

pkg/ai/handler.go

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -65,32 +65,23 @@ func HandleChat(c *gin.Context) {
6565
c.Header("Connection", "keep-alive")
6666
c.Header("X-Accel-Buffering", "no")
6767

68-
flusher, ok := c.Writer.(http.Flusher)
69-
if !ok {
70-
c.JSON(http.StatusInternalServerError, gin.H{"error": "Streaming not supported"})
71-
return
72-
}
73-
7468
sendEvent := func(event SSEEvent) {
7569
data := MarshalSSEEvent(event)
7670
_, _ = fmt.Fprint(c.Writer, data)
77-
flusher.Flush()
71+
c.Writer.Flush()
7872
}
7973

8074
agent.ProcessChat(c, &req, sendEvent)
8175

8276
sendEvent(SSEEvent{Event: "done", Data: map[string]string{}})
8377
}
8478

85-
// ExecuteRequest is the request body for the stateless execute endpoint.
86-
type ExecuteRequest struct {
87-
Tool string `json:"tool"`
88-
Args map[string]interface{} `json:"args"`
79+
type ContinueRequest struct {
80+
SessionID string `json:"sessionId"`
8981
}
9082

91-
// HandleExecute executes a confirmed mutation action. Stateless — the client
92-
// sends the full tool name and args, no server-side session needed.
93-
func HandleExecute(c *gin.Context) {
83+
// HandleExecuteContinue resumes a pending AI action after user confirmation.
84+
func HandleExecuteContinue(c *gin.Context) {
9485
cfg, err := LoadRuntimeConfig()
9586
if err != nil {
9687
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to load AI config: %v", err)})
@@ -101,14 +92,13 @@ func HandleExecute(c *gin.Context) {
10192
return
10293
}
10394

104-
var req ExecuteRequest
95+
var req ContinueRequest
10596
if err := c.ShouldBindJSON(&req); err != nil {
10697
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request: %v", err)})
10798
return
10899
}
109-
110-
if !MutationTools[req.Tool] {
111-
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Tool %s is not a mutation tool", req.Tool)})
100+
if strings.TrimSpace(req.SessionID) == "" {
101+
c.JSON(http.StatusBadRequest, gin.H{"error": "sessionId is required"})
112102
return
113103
}
114104

@@ -118,25 +108,28 @@ func HandleExecute(c *gin.Context) {
118108
return
119109
}
120110

121-
result, isError := ExecuteTool(c.Request.Context(), c, clientSet, req.Tool, req.Args)
122-
if isError {
123-
statusCode := http.StatusInternalServerError
124-
if strings.HasPrefix(result, "Forbidden: ") {
125-
statusCode = http.StatusForbidden
126-
} else if strings.HasPrefix(result, "Error: ") || strings.HasPrefix(result, "Unknown tool: ") {
127-
statusCode = http.StatusBadRequest
128-
}
129-
c.JSON(statusCode, gin.H{
130-
"status": "error",
131-
"message": result,
132-
})
111+
agent, err := NewAgent(clientSet, cfg)
112+
if err != nil {
113+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to create AI agent: %v", err)})
133114
return
134115
}
135116

136-
c.JSON(http.StatusOK, gin.H{
137-
"status": "ok",
138-
"message": result,
139-
})
117+
c.Header("Content-Type", "text/event-stream")
118+
c.Header("Cache-Control", "no-cache")
119+
c.Header("Connection", "keep-alive")
120+
c.Header("X-Accel-Buffering", "no")
121+
122+
sendEvent := func(event SSEEvent) {
123+
data := MarshalSSEEvent(event)
124+
_, _ = fmt.Fprint(c.Writer, data)
125+
c.Writer.Flush()
126+
}
127+
128+
if err := agent.ContinuePendingAction(c, req.SessionID, sendEvent); err != nil {
129+
sendEvent(SSEEvent{Event: "error", Data: map[string]string{"message": err.Error()}})
130+
}
131+
132+
sendEvent(SSEEvent{Event: "done", Data: map[string]string{}})
140133
}
141134

142135
func HandleGetGeneralSetting(c *gin.Context) {

pkg/ai/openai.go

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ai
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"sort"
@@ -37,8 +38,38 @@ func (a *Agent) processChatOpenAI(c *gin.Context, req *ChatRequest, sendEvent fu
3738
}
3839
sysPrompt := buildContextualSystemPrompt(req.PageContext, runtimeCtx, language)
3940
messages := toOpenAIMessages(sysPrompt, req.Messages)
41+
a.runOpenAIConversation(ctx, c, messages, sendEvent)
42+
}
43+
44+
func (a *Agent) continueChatOpenAI(c *gin.Context, session pendingSession, sendEvent func(SSEEvent)) error {
45+
ctx := c.Request.Context()
46+
result, isError := ExecuteTool(ctx, c, a.cs, session.ToolCall.Name, session.ToolCall.Args)
47+
48+
sendEvent(SSEEvent{
49+
Event: "tool_result",
50+
Data: map[string]interface{}{
51+
"tool": session.ToolCall.Name,
52+
"result": result,
53+
"is_error": isError,
54+
},
55+
})
4056

41-
tools := OpenAIToolDefs()
57+
if isError {
58+
result = "Tool error: " + result
59+
}
60+
61+
session.OpenAIMessages = append(session.OpenAIMessages, openai.ToolMessage(result, session.ToolCall.ID))
62+
a.runOpenAIConversation(ctx, c, session.OpenAIMessages, sendEvent)
63+
return nil
64+
}
65+
66+
func (a *Agent) runOpenAIConversation(
67+
ctx context.Context,
68+
c *gin.Context,
69+
messages []openai.ChatCompletionMessageParamUnion,
70+
sendEvent func(SSEEvent),
71+
) {
72+
tools := OpenAIToolDefs(a.cs)
4273

4374
maxIterations := 100
4475
for i := 0; i < maxIterations; i++ {
@@ -107,11 +138,26 @@ func (a *Agent) processChatOpenAI(c *gin.Context, req *ChatRequest, sendEvent fu
107138
messages = append(messages, openai.ToolMessage("Tool error: "+result, tc.ID))
108139
continue
109140
}
141+
sessionID := agentPendingSessions.save(pendingSession{
142+
Provider: a.provider,
143+
OpenAIMessages: append([]openai.ChatCompletionMessageParamUnion(nil), messages...),
144+
ToolCall: pendingToolCall{
145+
ID: tc.ID,
146+
Name: toolName,
147+
Args: args,
148+
},
149+
})
150+
if sessionID == "" {
151+
errorMsg := "Failed to save pending session"
152+
messages = append(messages, openai.ToolMessage("Tool error: "+errorMsg, tc.ID))
153+
continue
154+
}
110155
sendEvent(SSEEvent{
111156
Event: "action_required",
112157
Data: map[string]interface{}{
113-
"tool": toolName,
114-
"args": args,
158+
"tool": toolName,
159+
"args": args,
160+
"session_id": sessionID,
115161
},
116162
})
117163
return

0 commit comments

Comments
 (0)