88"""
99import asyncio
1010import json
11- import logging
11+ import os
1212import re
1313import time
1414from functools import reduce
2323from application .flow .i_step_node import NodeResult , INode
2424from application .flow .step_node .ai_chat_step_node .i_chat_node import IChatNode
2525from application .flow .tools import Reasoning
26+ from common .utils .logger import maxkb_logger
27+ from common .utils .tool_code import ToolExecutor
2628from models_provider .models import Model
2729from models_provider .tools import get_model_credential , get_model_instance_by_model_workspace_id
28- from common . utils . logger import maxkb_logger
30+ from tools . models import Tool
2931
3032tool_message_template = """
3133<details>
@@ -211,6 +213,10 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
211213 model_setting = None ,
212214 mcp_enable = False ,
213215 mcp_servers = None ,
216+ mcp_tool_id = None ,
217+ mcp_source = None ,
218+ tool_enable = False ,
219+ tool_ids = None ,
214220 ** kwargs ) -> NodeResult :
215221 if dialogue_type is None :
216222 dialogue_type = 'WORKFLOW'
@@ -234,12 +240,13 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
234240 message_list = self .generate_message_list (system , prompt , history_message )
235241 self .context ['message_list' ] = message_list
236242
237- if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers :
238- r = mcp_response_generator (chat_model , message_list , mcp_servers )
239- return NodeResult (
240- {'result' : r , 'chat_model' : chat_model , 'message_list' : message_list ,
241- 'history_message' : history_message , 'question' : question .content }, {},
242- _write_context = write_context_stream )
243+ # 处理 MCP 请求
244+ mcp_result = self ._handle_mcp_request (
245+ mcp_enable , tool_enable , mcp_source , mcp_servers , mcp_tool_id , tool_ids , chat_model , message_list ,
246+ history_message , question
247+ )
248+ if mcp_result :
249+ return mcp_result
243250
244251 if stream :
245252 r = chat_model .stream (message_list )
@@ -252,6 +259,48 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
252259 'history_message' : history_message , 'question' : question .content }, {},
253260 _write_context = write_context )
254261
262+ def _handle_mcp_request (self , mcp_enable , tool_enable , mcp_source , mcp_servers , mcp_tool_id , tool_ids ,
263+ chat_model , message_list , history_message , question ):
264+ if not mcp_enable and not tool_enable :
265+ return None
266+
267+ mcp_servers_config = {}
268+
269+ if mcp_enable :
270+ if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers :
271+ mcp_servers_config = json .loads (mcp_servers )
272+ elif mcp_tool_id :
273+ mcp_tool = QuerySet (Tool ).filter (id = mcp_tool_id ).first ()
274+ if mcp_tool :
275+ mcp_servers_config = json .loads (mcp_tool .code )
276+
277+ if tool_enable :
278+ if tool_ids and len (tool_ids ) > 0 : # 如果有工具ID,则将其转换为MCP
279+ self .context ['tool_ids' ] = tool_ids
280+ for tool_id in tool_ids :
281+ tool = QuerySet (Tool ).filter (id = tool_id ).first ()
282+ executor = ToolExecutor ()
283+ code = executor .generate_mcp_server_code (tool .code )
284+ code_path = f'{ executor .sandbox_path } /execute/{ tool_id } .py'
285+ with open (code_path , 'w' ) as f :
286+ f .write (code )
287+
288+ tool_config = {
289+ 'command' : 'python' ,
290+ 'args' : [code_path ],
291+ 'transport' : 'stdio' ,
292+ }
293+ mcp_servers_config [str (tool .id )] = tool_config
294+
295+ if len (mcp_servers_config ) > 0 :
296+ r = mcp_response_generator (chat_model , message_list , json .dumps (mcp_servers_config ))
297+ return NodeResult (
298+ {'result' : r , 'chat_model' : chat_model , 'message_list' : message_list ,
299+ 'history_message' : history_message , 'question' : question .content }, {},
300+ _write_context = write_context_stream )
301+
302+ return None
303+
255304 @staticmethod
256305 def get_history_message (history_chat_record , dialogue_number , dialogue_type , runtime_node_id ):
257306 start_index = len (history_chat_record ) - dialogue_number
@@ -284,6 +333,14 @@ def reset_message_list(message_list: List[BaseMessage], answer_text):
284333 return result
285334
286335 def get_details (self , index : int , ** kwargs ):
336+ # 删除临时生成的MCP代码文件
337+ if self .context .get ('tool_ids' ):
338+ executor = ToolExecutor ()
339+ # 清理工具代码文件,延时删除,避免文件被占用
340+ for tool_id in self .context .get ('tool_ids' ):
341+ code_path = f'{ executor .sandbox_path } /execute/{ tool_id } .py'
342+ if os .path .exists (code_path ):
343+ os .remove (code_path )
287344 return {
288345 'name' : self .node .properties .get ('stepName' ),
289346 "index" : index ,
0 commit comments