 @date:2024/1/9 18:25
 @desc: Base implementation of the chat step
 """
-import logging
+import json
+import os
 import time
 import traceback
 import uuid_utils.compat as uuid
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
-from application.flow.tools import Reasoning
+from application.flow.tools import Reasoning, mcp_response_generator
 from application.models import ApplicationChatUserStats, ChatUserType
 from common.utils.logger import maxkb_logger
+from common.utils.rsa_util import rsa_long_decrypt
+from common.utils.tool_code import ToolExecutor
+from maxkb.const import CONFIG
 from models_provider.tools import get_model_instance_by_model_workspace_id
+from tools.models import Tool


 def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
@@ -54,6 +59,7 @@ def write_context(step, manage, request_token, response_token, all_text):
     manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token


+
 def event_content(response,
                   chat_id,
                   chat_record_id,
@@ -169,6 +175,12 @@ def execute(self, message_list: List[BaseMessage],
                 no_references_setting=None,
                 model_params_setting=None,
                 model_setting=None,
+                mcp_enable=False,
+                mcp_tool_ids=None,
+                mcp_servers='',
+                mcp_source="referencing",
+                tool_enable=False,
+                tool_ids=None,
                 **kwargs):
         chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
                                                               **model_params_setting) if model_id is not None else None
@@ -177,14 +189,24 @@ def execute(self, message_list: List[BaseMessage],
                                        paragraph_list,
                                        manage, padding_problem_text, chat_user_id, chat_user_type,
                                        no_references_setting,
-                                       model_setting)
+                                       model_setting,
+                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
                                       manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
-                                      model_setting)
+                                      model_setting,
+                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)

     def get_details(self, manage, **kwargs):
+        # Remove the temporarily generated MCP code files
+        if self.context.get('execute_ids'):
+            executor = ToolExecutor(CONFIG.get('SANDBOX'))
+            # Clean up the tool code files; deletion is deferred to this point so files still in use are not removed
+            for tool_id in self.context.get('execute_ids'):
+                code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
+                if os.path.exists(code_path):
+                    os.remove(code_path)
         return {
             'step_type': 'chat_step',
             'run_time': self.context['run_time'],
@@ -206,12 +228,63 @@ def reset_message_list(message_list: List[BaseMessage], answer_text):
         result.append({'role': 'ai', 'content': answer_text})
         return result

-    @staticmethod
-    def get_stream_result(message_list: List[BaseMessage],
+    def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
+                            chat_model, message_list):
+        if not mcp_enable and not tool_enable:
+            return None
+
+        mcp_servers_config = {}
+
+        # mcp_source is None for records migrated from older versions
+        if mcp_source is None:
+            mcp_source = 'custom'
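+        # Two config sources: 'custom' parses the raw JSON carried in mcp_servers,
+        # while 'referencing' loads the stored configs of the selected Tool records.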
+        if mcp_enable:
+            # Backward compatibility with legacy data
+            if not mcp_tool_ids:
+                mcp_tool_ids = []
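+            # Custom configs containing a stdio transport are ignored entirely,
+            # presumably because stdio servers would spawn local processes (assumption).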
+            if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
+                mcp_servers_config = json.loads(mcp_servers)
+            elif mcp_tool_ids:
+                mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values()
+                for mcp_tool in mcp_tools:
+                    if mcp_tool and mcp_tool['is_active']:
+                        mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
+
+        if tool_enable:
+            if tool_ids and len(tool_ids) > 0:  # If tool IDs are present, convert the tools to MCP servers
+                self.context['tool_ids'] = tool_ids
+                self.context['execute_ids'] = []
+                for tool_id in tool_ids:
+                    tool = QuerySet(Tool).filter(id=tool_id).first()
+                    if tool is None or not tool.is_active:
+                        continue
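+                    # Wrap each active tool's code as a temporary MCP server config; the
+                    # generated file id is recorded in execute_ids so get_details() can delete it later.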
+                    executor = ToolExecutor(CONFIG.get('SANDBOX'))
+                    if tool.init_params is not None:
+                        params = json.loads(rsa_long_decrypt(tool.init_params))
+                    else:
+                        params = {}
+                    _id, tool_config = executor.get_tool_mcp_config(tool.code, params)
+
+                    self.context['execute_ids'].append(_id)
+                    mcp_servers_config[str(tool.id)] = tool_config
+
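+        # If any MCP servers were collected, hand the conversation off to
+        # mcp_response_generator, which presumably runs the model with tool calling
+        # against those servers and yields the response (see application.flow.tools).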
+        if len(mcp_servers_config) > 0:
+            return mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))
+
+        return None
+
+
+    def get_stream_result(self, message_list: List[BaseMessage],
                           chat_model: BaseChatModel = None,
                           paragraph_list=None,
                           no_references_setting=None,
-                          problem_text=None):
+                          problem_text=None,
+                          mcp_enable=False,
+                          mcp_tool_ids=None,
+                          mcp_servers='',
+                          mcp_source="referencing",
+                          tool_enable=False,
+                          tool_ids=None):
         if paragraph_list is None:
             paragraph_list = []
         directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -227,6 +300,12 @@ def get_stream_result(message_list: List[BaseMessage],
             return iter([AIMessageChunk(
                 _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
         else:
+            # Handle the MCP request first
+            mcp_result = self._handle_mcp_request(
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
+            )
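+            # True mirrors the is_ai_chat flag of the plain model path, so the MCP
+            # result gets the same downstream handling (e.g. token accounting).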
+            if mcp_result:
+                return mcp_result, True
             return chat_model.stream(message_list), True

     def execute_stream(self, message_list: List[BaseMessage],
@@ -239,9 +318,15 @@ def execute_stream(self, message_list: List[BaseMessage],
                        padding_problem_text: str = None,
                        chat_user_id=None, chat_user_type=None,
                        no_references_setting=None,
-                       model_setting=None):
+                       model_setting=None,
+                       mcp_enable=False,
+                       mcp_tool_ids=None,
+                       mcp_servers='',
+                       mcp_source="referencing",
+                       tool_enable=False,
+                       tool_ids=None):
         chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
-                                                          no_references_setting, problem_text)
+                                                          no_references_setting, problem_text, mcp_enable,
+                                                          mcp_tool_ids, mcp_servers, mcp_source, tool_enable,
+                                                          tool_ids)
         chat_record_id = uuid.uuid7()
         r = StreamingHttpResponse(
             streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@@ -253,12 +338,17 @@ def execute_stream(self, message_list: List[BaseMessage],
         r['Cache-Control'] = 'no-cache'
         return r

-    @staticmethod
-    def get_block_result(message_list: List[BaseMessage],
+    def get_block_result(self, message_list: List[BaseMessage],
                          chat_model: BaseChatModel = None,
                          paragraph_list=None,
                          no_references_setting=None,
-                         problem_text=None):
+                         problem_text=None,
+                         mcp_enable=False,
+                         mcp_tool_ids=None,
+                         mcp_servers='',
+                         mcp_source="referencing",
+                         tool_enable=False,
+                         tool_ids=None):
         if paragraph_list is None:
             paragraph_list = []
         directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -273,6 +363,12 @@ def get_block_result(message_list: List[BaseMessage],
             return AIMessage(
                 _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
         else:
+            # Handle the MCP request first
+            mcp_result = self._handle_mcp_request(
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
+            )
+            if mcp_result:
+                return mcp_result, True
             return chat_model.invoke(message_list), True

     def execute_block(self, message_list: List[BaseMessage],
@@ -284,7 +380,13 @@ def execute_block(self, message_list: List[BaseMessage],
                       manage: PipelineManage = None,
                       padding_problem_text: str = None,
                       chat_user_id=None, chat_user_type=None, no_references_setting=None,
-                      model_setting=None):
+                      model_setting=None,
+                      mcp_enable=False,
+                      mcp_tool_ids=None,
+                      mcp_servers='',
+                      mcp_source="referencing",
+                      tool_enable=False,
+                      tool_ids=None):
         reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
         reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
         reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
@@ -294,7 +396,7 @@ def execute_block(self, message_list: List[BaseMessage],
         # Invoke the model
         try:
             chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
-                                                            no_references_setting, problem_text)
+                                                            no_references_setting, problem_text, mcp_enable,
+                                                            mcp_tool_ids, mcp_servers, mcp_source, tool_enable,
+                                                            tool_ids)
             if is_ai_chat:
                 request_token = chat_model.get_num_tokens_from_messages(message_list)
                 response_token = chat_model.get_num_tokens(chat_result.content)