@@ -36,13 +36,14 @@ type configGetter interface {
3636// ConnectionHandler 连接处理器结构
3737type ConnectionHandler struct {
3838 // 确保实现 AsrEventListener 接口
39- _ providers.AsrEventListener
40- config * configs.Config
41- logger * utils.Logger
42- conn Conn
43- closeOnce sync.Once
44- taskMgr * task.TaskManager
45- providers struct {
39+ _ providers.AsrEventListener
40+ config * configs.Config
41+ logger * utils.Logger
42+ conn Conn
43+ closeOnce sync.Once
44+ taskMgr * task.TaskManager
45+ safeCallbackFunc func (func (* ConnectionHandler )) func ()
46+ providers struct {
4647 asr providers.ASRProvider
4748 llm providers.LLMProvider
4849 tts providers.TTSProvider
@@ -106,6 +107,8 @@ type ConnectionHandler struct {
106107 // functions
107108 functionRegister * function.FunctionRegistry
108109 mcpManager * mcp.Manager
110+
111+ ctx context.Context
109112}
110113
111114// NewConnectionHandler 创建新的连接处理器
@@ -114,6 +117,7 @@ func NewConnectionHandler(
114117 providerSet * pool.ProviderSet ,
115118 logger * utils.Logger ,
116119 req * http.Request ,
120+ ctx context.Context ,
117121) * ConnectionHandler {
118122 handler := & ConnectionHandler {
119123 config : config ,
@@ -143,6 +147,8 @@ func NewConnectionHandler(
143147 serverAudioChannels : 1 ,
144148 serverAudioFrameDuration : 60 ,
145149
150+ ctx : ctx ,
151+
146152 headers : make (map [string ]string ),
147153 }
148154
@@ -187,6 +193,37 @@ func NewConnectionHandler(
187193 return handler
188194}
189195
196+ func (h * ConnectionHandler ) SetTaskCallback (callback func (func (* ConnectionHandler )) func ()) {
197+ h .safeCallbackFunc = callback
198+ }
199+
200+ func (h * ConnectionHandler ) SubmitTask (taskType string , params map [string ]interface {}) {
201+ _task , id := task .NewTask (h .ctx , "" , params )
202+ h .logger .Info (fmt .Sprintf ("提交任务: %s, ID: %s, 参数: %v" , _task .Type , id , params ))
203+ // 创建安全回调用于任务完成时调用
204+ var taskCallback func (result interface {})
205+ if h .safeCallbackFunc != nil {
206+ taskCallback = func (result interface {}) {
207+ fmt .Print ("任务完成回调: " )
208+ safeCallback := h .safeCallbackFunc (func (handler * ConnectionHandler ) {
209+ // 处理任务完成逻辑
210+ handler .handleTaskComplete (_task , id , result )
211+ })
212+ // 执行安全回调
213+ if safeCallback != nil {
214+ safeCallback ()
215+ }
216+ }
217+ }
218+ cb := task .NewCallBack (taskCallback )
219+ _task .Callback = cb
220+ h .taskMgr .SubmitTask (h .sessionID , _task )
221+ }
222+
223+ func (h * ConnectionHandler ) handleTaskComplete (task * task.Task , id string , result interface {}) {
224+ h .logger .Info (fmt .Sprintf ("任务 %s 完成,ID: %s, %v" , task .Type , id , result ))
225+ }
226+
190227// Handle 处理WebSocket连接
191228func (h * ConnectionHandler ) Handle (conn Conn ) {
192229 defer conn .Close ()
0 commit comments