Skip to content

Commit e9c8c95

Browse files
committed
feat: add MCP tool ID and source fields to chat node for enhanced configuration
1 parent 1875368 commit e9c8c95

File tree

2 files changed

+76
-10
lines changed

2 files changed

+76
-10
lines changed

apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,14 @@ class ChatNodeSerializer(serializers.Serializer):
3131
label='Model settings')
3232
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
3333
label=_("Context Type"))
34-
mcp_enable = serializers.BooleanField(required=False,
35-
label=_("Whether to enable MCP"))
34+
mcp_enable = serializers.BooleanField(required=False, label=_("Whether to enable MCP"))
3635
mcp_servers = serializers.JSONField(required=False, label=_("MCP Server"))
36+
mcp_tool_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Tool ID"))
37+
mcp_source = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Source"))
38+
39+
tool_enable = serializers.BooleanField(required=False, default=False, label=_("Whether to enable tools"))
40+
tool_ids = serializers.ListField(child=serializers.UUIDField(), required=False, allow_empty=True,
41+
label=_("Tool IDs"), )
3742

3843

3944
class IChatNode(INode):
@@ -52,5 +57,9 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
5257
model_setting=None,
5358
mcp_enable=False,
5459
mcp_servers=None,
60+
mcp_tool_id=None,
61+
mcp_source=None,
62+
tool_enable=False,
63+
tool_ids=None,
5564
**kwargs) -> NodeResult:
5665
pass

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99
import asyncio
1010
import json
11-
import logging
11+
import os
1212
import re
1313
import time
1414
from functools import reduce
@@ -23,9 +23,11 @@
2323
from application.flow.i_step_node import NodeResult, INode
2424
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
2525
from application.flow.tools import Reasoning
26+
from common.utils.logger import maxkb_logger
27+
from common.utils.tool_code import ToolExecutor
2628
from models_provider.models import Model
2729
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
28-
from common.utils.logger import maxkb_logger
30+
from tools.models import Tool
2931

3032
tool_message_template = """
3133
<details>
@@ -211,6 +213,10 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
211213
model_setting=None,
212214
mcp_enable=False,
213215
mcp_servers=None,
216+
mcp_tool_id=None,
217+
mcp_source=None,
218+
tool_enable=False,
219+
tool_ids=None,
214220
**kwargs) -> NodeResult:
215221
if dialogue_type is None:
216222
dialogue_type = 'WORKFLOW'
@@ -234,12 +240,13 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
234240
message_list = self.generate_message_list(system, prompt, history_message)
235241
self.context['message_list'] = message_list
236242

237-
if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers:
238-
r = mcp_response_generator(chat_model, message_list, mcp_servers)
239-
return NodeResult(
240-
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
241-
'history_message': history_message, 'question': question.content}, {},
242-
_write_context=write_context_stream)
243+
# 处理 MCP 请求
244+
mcp_result = self._handle_mcp_request(
245+
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids, chat_model, message_list,
246+
history_message, question
247+
)
248+
if mcp_result:
249+
return mcp_result
243250

244251
if stream:
245252
r = chat_model.stream(message_list)
@@ -252,6 +259,48 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
252259
'history_message': history_message, 'question': question.content}, {},
253260
_write_context=write_context)
254261

262+
def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids,
263+
chat_model, message_list, history_message, question):
264+
if not mcp_enable and not tool_enable:
265+
return None
266+
267+
mcp_servers_config = {}
268+
269+
if mcp_enable:
270+
if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
271+
mcp_servers_config = json.loads(mcp_servers)
272+
elif mcp_tool_id:
273+
mcp_tool = QuerySet(Tool).filter(id=mcp_tool_id).first()
274+
if mcp_tool:
275+
mcp_servers_config = json.loads(mcp_tool.code)
276+
277+
if tool_enable:
278+
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
279+
self.context['tool_ids'] = tool_ids
280+
for tool_id in tool_ids:
281+
tool = QuerySet(Tool).filter(id=tool_id).first()
282+
executor = ToolExecutor()
283+
code = executor.generate_mcp_server_code(tool.code)
284+
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
285+
with open(code_path, 'w') as f:
286+
f.write(code)
287+
288+
tool_config = {
289+
'command': 'python',
290+
'args': [code_path],
291+
'transport': 'stdio',
292+
}
293+
mcp_servers_config[str(tool.id)] = tool_config
294+
295+
if len(mcp_servers_config) > 0:
296+
r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))
297+
return NodeResult(
298+
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
299+
'history_message': history_message, 'question': question.content}, {},
300+
_write_context=write_context_stream)
301+
302+
return None
303+
255304
@staticmethod
256305
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
257306
start_index = len(history_chat_record) - dialogue_number
@@ -284,6 +333,14 @@ def reset_message_list(message_list: List[BaseMessage], answer_text):
284333
return result
285334

286335
def get_details(self, index: int, **kwargs):
336+
# 删除临时生成的MCP代码文件
337+
if self.context.get('tool_ids'):
338+
executor = ToolExecutor()
339+
# 清理工具代码文件,延时删除,避免文件被占用
340+
for tool_id in self.context.get('tool_ids'):
341+
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
342+
if os.path.exists(code_path):
343+
os.remove(code_path)
287344
return {
288345
'name': self.node.properties.get('stepName'),
289346
"index": index,

0 commit comments

Comments
 (0)