Skip to content

Commit c524906

Browse files
authored
Merge pull request #43 from AnimeAIChat/develop
重构任务管理,连接优化
2 parents 48aa1ff + 583e442 commit c524906

File tree

17 files changed

+707
-478
lines changed

17 files changed

+707
-478
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ require (
4242
golang.org/x/arch v0.17.0 // indirect
4343
golang.org/x/crypto v0.38.0 // indirect
4444
golang.org/x/net v0.40.0 // indirect
45-
golang.org/x/sync v0.14.0 // indirect
4645
golang.org/x/sys v0.33.0 // indirect
4746
golang.org/x/text v0.25.0 // indirect
4847
google.golang.org/protobuf v1.36.6 // indirect
48+
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
4949
)

go.sum

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,11 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02
5353
github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
5454
github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
5555
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
56+
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
5657
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
5758
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
59+
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
60+
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
5861
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
5962
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
6063
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
@@ -124,8 +127,9 @@ golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJ
124127
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
125128
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
126129
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
127-
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
128130
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
131+
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
132+
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
129133
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
130134
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
131135
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package Auth
1+
package auth
22

33
import (
44
"errors"

src/core/connection.go

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@ type configGetter interface {
3636
// ConnectionHandler 连接处理器结构
3737
type ConnectionHandler struct {
3838
// 确保实现 AsrEventListener 接口
39-
_ providers.AsrEventListener
40-
config *configs.Config
41-
logger *utils.Logger
42-
conn Conn
43-
closeOnce sync.Once
44-
taskMgr *task.TaskManager
45-
providers struct {
39+
_ providers.AsrEventListener
40+
config *configs.Config
41+
logger *utils.Logger
42+
conn Conn
43+
closeOnce sync.Once
44+
taskMgr *task.TaskManager
45+
safeCallbackFunc func(func(*ConnectionHandler)) func()
46+
providers struct {
4647
asr providers.ASRProvider
4748
llm providers.LLMProvider
4849
tts providers.TTSProvider
@@ -106,6 +107,8 @@ type ConnectionHandler struct {
106107
// functions
107108
functionRegister *function.FunctionRegistry
108109
mcpManager *mcp.Manager
110+
111+
ctx context.Context
109112
}
110113

111114
// NewConnectionHandler 创建新的连接处理器
@@ -114,6 +117,7 @@ func NewConnectionHandler(
114117
providerSet *pool.ProviderSet,
115118
logger *utils.Logger,
116119
req *http.Request,
120+
ctx context.Context,
117121
) *ConnectionHandler {
118122
handler := &ConnectionHandler{
119123
config: config,
@@ -143,6 +147,8 @@ func NewConnectionHandler(
143147
serverAudioChannels: 1,
144148
serverAudioFrameDuration: 60,
145149

150+
ctx: ctx,
151+
146152
headers: make(map[string]string),
147153
}
148154

@@ -187,6 +193,37 @@ func NewConnectionHandler(
187193
return handler
188194
}
189195

196+
func (h *ConnectionHandler) SetTaskCallback(callback func(func(*ConnectionHandler)) func()) {
197+
h.safeCallbackFunc = callback
198+
}
199+
200+
func (h *ConnectionHandler) SubmitTask(taskType string, params map[string]interface{}) {
201+
_task, id := task.NewTask(h.ctx, "", params)
202+
h.logger.Info(fmt.Sprintf("提交任务: %s, ID: %s, 参数: %v", _task.Type, id, params))
203+
// 创建安全回调用于任务完成时调用
204+
var taskCallback func(result interface{})
205+
if h.safeCallbackFunc != nil {
206+
taskCallback = func(result interface{}) {
207+
fmt.Print("任务完成回调: ")
208+
safeCallback := h.safeCallbackFunc(func(handler *ConnectionHandler) {
209+
// 处理任务完成逻辑
210+
handler.handleTaskComplete(_task, id, result)
211+
})
212+
// 执行安全回调
213+
if safeCallback != nil {
214+
safeCallback()
215+
}
216+
}
217+
}
218+
cb := task.NewCallBack(taskCallback)
219+
_task.Callback = cb
220+
h.taskMgr.SubmitTask(h.sessionID, _task)
221+
}
222+
223+
func (h *ConnectionHandler) handleTaskComplete(task *task.Task, id string, result interface{}) {
224+
h.logger.Info(fmt.Sprintf("任务 %s 完成,ID: %s, %v", task.Type, id, result))
225+
}
226+
190227
// Handle 处理WebSocket连接
191228
func (h *ConnectionHandler) Handle(conn Conn) {
192229
defer conn.Close()

src/core/connection_handlemsg.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"xiaozhi-server-go/src/core/image"
1111
"xiaozhi-server-go/src/core/providers"
1212
"xiaozhi-server-go/src/core/utils"
13-
"xiaozhi-server-go/src/task"
1413
)
1514

1615
// handleMessage 处理接收到的消息
@@ -102,17 +101,6 @@ func (h *ConnectionHandler) handleVisionMessage(msgMap map[string]interface{}) e
102101
// 处理视觉消息
103102
cmd := msgMap["cmd"].(string)
104103
if cmd == "gen_pic" {
105-
text := msgMap["text"].(string)
106-
params := map[string]interface{}{
107-
"prompt": text,
108-
"size": "1024x1024",
109-
"quality": "standard",
110-
"api_key": h.config.LLM["ChatGLMLLM"].APIKey,
111-
"client_id": h.sessionID,
112-
}
113-
task, id := task.NewTask(task.TaskTypeImageGen, params, task.NewMessageCallback(h.conn, "vision", cmd))
114-
h.taskMgr.SubmitTask(h.sessionID, task)
115-
h.logger.Info(fmt.Sprintf("生成图片任务提交成功: %s, %s", text, id))
116104
} else if cmd == "gen_video" {
117105
} else if cmd == "read_img" {
118106
}

src/core/mcp/manager.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ func (m *Manager) BindConnection(conn Conn, fh types.FunctionRegistryInterface,
144144
} else {
145145
// 重新绑定连接而不是重新创建
146146
m.XiaoZhiMCPClient.SetConnection(conn)
147+
m.XiaoZhiMCPClient.SetID(deviceID, clientID)
148+
m.XiaoZhiMCPClient.SetToken(token)
147149
if !m.XiaoZhiMCPClient.IsReady() {
148150
if err := m.XiaoZhiMCPClient.Start(context.Background()); err != nil {
149151
return fmt.Errorf("重启XiaoZhi MCP客户端失败: %v", err)

src/core/mcp/xiaozhi_client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"sync"
99
"time"
1010

11-
"xiaozhi-server-go/src/core/Auth"
11+
"xiaozhi-server-go/src/core/auth"
1212
"xiaozhi-server-go/src/core/utils"
1313

1414
"github.com/sashabaranov/go-openai"
@@ -75,7 +75,7 @@ func (c *XiaoZhiMCPClient) SetID(deviceID string, clientID string) {
7575
}
7676

7777
func (c *XiaoZhiMCPClient) SetToken(token string) {
78-
auth := Auth.NewAuthToken(token)
78+
auth := auth.NewAuthToken(token)
7979
visionToken, err := auth.GenerateToken(c.deviceID)
8080

8181
if err != nil {

src/core/websocket_conn.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package core
2+
3+
import (
4+
"errors"
5+
"sync"
6+
"sync/atomic"
7+
"time"
8+
9+
"github.com/gorilla/websocket"
10+
)
11+
12+
var (
13+
ErrConnectionClosed = errors.New("websocket connection is closed")
14+
)
15+
16+
// websocketConn 封装gorilla/websocket的连接实现
17+
type websocketConn struct {
18+
conn *websocket.Conn
19+
writeMu sync.Mutex // 写操作互斥锁
20+
closed int32 // 原子操作标记连接状态 (0=open, 1=closed)
21+
lastActive int64 // 最后活跃时间戳(原子操作)
22+
}
23+
24+
func (w *websocketConn) ReadMessage() (messageType int, p []byte, err error) {
25+
if atomic.LoadInt32(&w.closed) == 1 {
26+
return 0, nil, ErrConnectionClosed
27+
}
28+
29+
// 设置读取超时
30+
w.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
31+
32+
messageType, p, err = w.conn.ReadMessage()
33+
if err != nil {
34+
// 如果读取出错,标记连接为已关闭
35+
atomic.StoreInt32(&w.closed, 1)
36+
return 0, nil, err
37+
}
38+
39+
// 更新最后活跃时间
40+
atomic.StoreInt64(&w.lastActive, time.Now().Unix())
41+
42+
return messageType, p, nil
43+
}
44+
45+
func (w *websocketConn) WriteMessage(messageType int, data []byte) error {
46+
// 检查连接状态
47+
if atomic.LoadInt32(&w.closed) == 1 {
48+
return ErrConnectionClosed
49+
}
50+
51+
// 使用写锁确保写操作的串行化
52+
w.writeMu.Lock()
53+
defer w.writeMu.Unlock()
54+
55+
// 双重检查,防止在获取锁的过程中连接被关闭
56+
if atomic.LoadInt32(&w.closed) == 1 {
57+
return ErrConnectionClosed
58+
}
59+
60+
// 设置写入超时
61+
w.conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
62+
63+
err := w.conn.WriteMessage(messageType, data)
64+
if err != nil {
65+
// 如果写入出错,标记连接为已关闭
66+
atomic.StoreInt32(&w.closed, 1)
67+
return err
68+
}
69+
70+
// 更新最后活跃时间
71+
atomic.StoreInt64(&w.lastActive, time.Now().Unix())
72+
73+
return nil
74+
}
75+
76+
func (w *websocketConn) Close() error {
77+
// 使用原子操作避免重复关闭
78+
if !atomic.CompareAndSwapInt32(&w.closed, 0, 1) {
79+
return nil // 已经关闭过了
80+
}
81+
82+
w.writeMu.Lock()
83+
defer w.writeMu.Unlock()
84+
85+
// 尝试发送关闭帧(不强制要求成功)
86+
closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "connection closed")
87+
w.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
88+
w.conn.WriteMessage(websocket.CloseMessage, closeMsg)
89+
90+
return w.conn.Close()
91+
}
92+
93+
// IsClosed 检查连接是否已关闭
94+
func (w *websocketConn) IsClosed() bool {
95+
return atomic.LoadInt32(&w.closed) == 1
96+
}
97+
98+
// GetLastActiveTime 获取最后活跃时间
99+
func (w *websocketConn) GetLastActiveTime() time.Time {
100+
timestamp := atomic.LoadInt64(&w.lastActive)
101+
return time.Unix(timestamp, 0)
102+
}
103+
104+
// IsStale 检查连接是否过期(基于最后活跃时间)
105+
func (w *websocketConn) IsStale(timeout time.Duration) bool {
106+
if w.IsClosed() {
107+
return true
108+
}
109+
return time.Since(w.GetLastActiveTime()) > timeout
110+
}

0 commit comments

Comments
 (0)