
Commit 90a43e2

feat: add MCP and tool configuration options in application settings
1 parent 76766eb

File tree

10 files changed: +493 −86 lines changed

apps/application/chat_pipeline/step/chat_step/i_chat_step.py

Lines changed: 10 additions & 1 deletion
@@ -82,6 +82,12 @@ class InstanceSerializer(serializers.Serializer):
         model_params_setting = serializers.DictField(required=False, allow_null=True,
                                                      label=_("Model parameter settings"))
+        mcp_enable = serializers.BooleanField(label="Whether MCP is enabled", required=False, default=False)
+        mcp_tool_ids = serializers.JSONField(label="MCP tool ID list", required=False, default=list)
+        mcp_servers = serializers.JSONField(label="MCP server list", required=False, default=dict)
+        mcp_source = serializers.CharField(label="MCP Source", required=False, default="referencing")
+        tool_enable = serializers.BooleanField(label="Whether tools are enabled", required=False, default=False)
+        tool_ids = serializers.JSONField(label="Tool ID list", required=False, default=list)

         def is_valid(self, *, raise_exception=False):
             super().is_valid(raise_exception=True)
@@ -106,5 +112,8 @@ def execute(self, message_list: List[BaseMessage],
                 paragraph_list=None,
                 manage: PipelineManage = None,
                 padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
-                no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
+                no_references_setting=None, model_params_setting=None, model_setting=None,
+                mcp_enable=False, mcp_tool_ids=None, mcp_servers='', mcp_source="referencing",
+                tool_enable=False, tool_ids=None,
+                **kwargs):
         pass
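For context, here is a sketch of the MCP/tool portion of a chat-step settings payload these new fields would accept, assuming the serializer stays nested under IChatStep; the IDs are illustrative placeholders and the remaining serializer fields are omitted:

    # Hypothetical payload fragment; the IDs below are placeholders, not real tools.
    chat_step_settings = {
        'mcp_enable': True,
        'mcp_source': 'referencing',  # 'referencing' resolves stored configs via mcp_tool_ids;
                                      # 'custom' expects mcp_servers as a JSON string instead
        'mcp_tool_ids': ['018f3b5e-0000-7000-8000-000000000001'],
        'tool_enable': True,
        'tool_ids': ['018f3b5e-0000-7000-8000-000000000002'],
    }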

apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py

Lines changed: 116 additions & 14 deletions
@@ -6,7 +6,8 @@
 @date:2024/1/9 18:25
 @desc: Base implementation of the chat step
 """
-import logging
+import json
+import os
 import time
 import traceback
 import uuid_utils.compat as uuid
@@ -24,10 +25,14 @@
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
-from application.flow.tools import Reasoning
+from application.flow.tools import Reasoning, mcp_response_generator
 from application.models import ApplicationChatUserStats, ChatUserType
 from common.utils.logger import maxkb_logger
+from common.utils.rsa_util import rsa_long_decrypt
+from common.utils.tool_code import ToolExecutor
+from maxkb.const import CONFIG
 from models_provider.tools import get_model_instance_by_model_workspace_id
+from tools.models import Tool


 def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
@@ -54,6 +59,7 @@ def write_context(step, manage, request_token, response_token, all_text):
     manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token


+
 def event_content(response,
                   chat_id,
                   chat_record_id,
@@ -169,6 +175,12 @@ def execute(self, message_list: List[BaseMessage],
                 no_references_setting=None,
                 model_params_setting=None,
                 model_setting=None,
+                mcp_enable=False,
+                mcp_tool_ids=None,
+                mcp_servers='',
+                mcp_source="referencing",
+                tool_enable=False,
+                tool_ids=None,
                 **kwargs):
         chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
                                                               **model_params_setting) if model_id is not None else None
@@ -177,14 +189,24 @@ def execute(self, message_list: List[BaseMessage],
                                        paragraph_list,
                                        manage, padding_problem_text, chat_user_id, chat_user_type,
                                        no_references_setting,
-                                       model_setting)
+                                       model_setting,
+                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
                                       manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
-                                      model_setting)
+                                      model_setting,
+                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)

     def get_details(self, manage, **kwargs):
+        # Remove the temporarily generated MCP code files
+        if self.context.get('execute_ids'):
+            executor = ToolExecutor(CONFIG.get('SANDBOX'))
+            # Clean up the tool code files; deletion is deferred so files still in use are not removed
+            for tool_id in self.context.get('execute_ids'):
+                code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
+                if os.path.exists(code_path):
+                    os.remove(code_path)
         return {
             'step_type': 'chat_step',
             'run_time': self.context['run_time'],
@@ -206,12 +228,63 @@ def reset_message_list(message_list: List[BaseMessage], answer_text):
             result.append({'role': 'ai', 'content': answer_text})
         return result

-    @staticmethod
-    def get_stream_result(message_list: List[BaseMessage],
+    def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
+                            chat_model, message_list):
+        if not mcp_enable and not tool_enable:
+            return None
+
+        mcp_servers_config = {}
+
+        # For settings migrated from older versions, mcp_source is None
+        if mcp_source is None:
+            mcp_source = 'custom'
+        if mcp_enable:
+            # Backward compatibility with legacy data
+            if not mcp_tool_ids:
+                mcp_tool_ids = []
+            if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
+                mcp_servers_config = json.loads(mcp_servers)
+            elif mcp_tool_ids:
+                mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values()
+                for mcp_tool in mcp_tools:
+                    if mcp_tool and mcp_tool['is_active']:
+                        mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
+
+        if tool_enable:
+            if tool_ids and len(tool_ids) > 0:  # If tool IDs are given, convert those tools to MCP servers
+                self.context['tool_ids'] = tool_ids
+                self.context['execute_ids'] = []
+                for tool_id in tool_ids:
+                    tool = QuerySet(Tool).filter(id=tool_id).first()
+                    if not tool.is_active:
+                        continue
+                    executor = ToolExecutor(CONFIG.get('SANDBOX'))
+                    if tool.init_params is not None:
+                        params = json.loads(rsa_long_decrypt(tool.init_params))
+                    else:
+                        params = {}
+                    _id, tool_config = executor.get_tool_mcp_config(tool.code, params)
+
+                    self.context['execute_ids'].append(_id)
+                    mcp_servers_config[str(tool.id)] = tool_config
+
+        if len(mcp_servers_config) > 0:
+            return mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))
+
+        return None
+
+    def get_stream_result(self, message_list: List[BaseMessage],
                           chat_model: BaseChatModel = None,
                           paragraph_list=None,
                           no_references_setting=None,
-                          problem_text=None):
+                          problem_text=None,
+                          mcp_enable=False,
+                          mcp_tool_ids=None,
+                          mcp_servers='',
+                          mcp_source="referencing",
+                          tool_enable=False,
+                          tool_ids=None):
         if paragraph_list is None:
             paragraph_list = []
         directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -227,6 +300,12 @@ def get_stream_result(message_list: List[BaseMessage],
             return iter([AIMessageChunk(
                 _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
         else:
+            # Handle the MCP request
+            mcp_result = self._handle_mcp_request(
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
+            )
+            if mcp_result:
+                return mcp_result, True
             return chat_model.stream(message_list), True

     def execute_stream(self, message_list: List[BaseMessage],
@@ -239,9 +318,15 @@ def execute_stream(self, message_list: List[BaseMessage],
                        padding_problem_text: str = None,
                        chat_user_id=None, chat_user_type=None,
                        no_references_setting=None,
-                       model_setting=None):
+                       model_setting=None,
+                       mcp_enable=False,
+                       mcp_tool_ids=None,
+                       mcp_servers='',
+                       mcp_source="referencing",
+                       tool_enable=False,
+                       tool_ids=None):
         chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
-                                                         no_references_setting, problem_text)
+                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
         chat_record_id = uuid.uuid7()
         r = StreamingHttpResponse(
             streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@@ -253,12 +338,17 @@
         r['Cache-Control'] = 'no-cache'
         return r

-    @staticmethod
-    def get_block_result(message_list: List[BaseMessage],
+    def get_block_result(self, message_list: List[BaseMessage],
                          chat_model: BaseChatModel = None,
                          paragraph_list=None,
                          no_references_setting=None,
-                         problem_text=None):
+                         problem_text=None,
+                         mcp_enable=False,
+                         mcp_tool_ids=None,
+                         mcp_servers='',
+                         mcp_source="referencing",
+                         tool_enable=False,
+                         tool_ids=None):
         if paragraph_list is None:
             paragraph_list = []
         directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -273,6 +363,12 @@ def get_block_result(message_list: List[BaseMessage],
             return AIMessage(
                 _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
         else:
+            # Handle the MCP request
+            mcp_result = self._handle_mcp_request(
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
+            )
+            if mcp_result:
+                return mcp_result, True
             return chat_model.invoke(message_list), True

     def execute_block(self, message_list: List[BaseMessage],
@@ -284,7 +380,13 @@ def execute_block(self, message_list: List[BaseMessage],
                       manage: PipelineManage = None,
                       padding_problem_text: str = None,
                       chat_user_id=None, chat_user_type=None, no_references_setting=None,
-                      model_setting=None):
+                      model_setting=None,
+                      mcp_enable=False,
+                      mcp_tool_ids=None,
+                      mcp_servers='',
+                      mcp_source="referencing",
+                      tool_enable=False,
+                      tool_ids=None):
         reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
         reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
         reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
@@ -294,7 +396,7 @@ def execute_block(self, message_list: List[BaseMessage],
         # Invoke the model
         try:
             chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
-                                                            no_references_setting, problem_text)
+                                                            no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
             if is_ai_chat:
                 request_token = chat_model.get_num_tokens_from_messages(message_list)
                 response_token = chat_model.get_num_tokens(chat_result.content)
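For reference, the mcp_servers_config mapping that _handle_mcp_request assembles above (and serializes with json.dumps for mcp_response_generator) appears to follow the MultiServerMCPClient connection format from langchain_mcp_adapters: one entry per server, keyed by a server name or, for converted tools, the tool ID. A minimal sketch with an illustrative name and endpoint; note that the 'custom' branch above deliberately rejects configs containing a "stdio" transport:

    # Illustrative only: the shape json.loads(mcp_servers) is expected to yield.
    mcp_servers_config = {
        'weather': {                             # arbitrary server name
            'url': 'http://localhost:8000/sse',  # hypothetical SSE endpoint
            'transport': 'sse',
        },
    }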

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 3 additions & 67 deletions
@@ -6,59 +6,28 @@
 @date:2024/6/4 14:30
 @desc:
 """
-import asyncio
 import json
 import os
 import re
-import sys
 import time
-import traceback
 from functools import reduce
 from typing import List, Dict

-import uuid_utils.compat as uuid
 from django.db.models import QuerySet
 from langchain.schema import HumanMessage, SystemMessage
-from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage
-from langchain_mcp_adapters.client import MultiServerMCPClient
-from langgraph.prebuilt import create_react_agent
+from langchain_core.messages import BaseMessage, AIMessage
+

 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
-from application.flow.tools import Reasoning
-from common.utils.logger import maxkb_logger
+from application.flow.tools import Reasoning, mcp_response_generator
 from common.utils.rsa_util import rsa_long_decrypt
 from common.utils.tool_code import ToolExecutor
 from maxkb.const import CONFIG
 from models_provider.models import Model
 from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
 from tools.models import Tool

-tool_message_template = """
-<details>
-<summary>
-<strong>Called MCP Tool: <em>%s</em></strong>
-</summary>
-
-%s
-
-</details>
-
-"""
-
-tool_message_json_template = """
-```json
-%s
-```
-"""
-
-
-def generate_tool_message_template(name, context):
-    if '```' in context:
-        return tool_message_template % (name, context)
-    else:
-        return tool_message_template % (name, tool_message_json_template % (context))
-

 def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
                    reasoning_content: str):
@@ -122,39 +91,6 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode,
     _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)


-async def _yield_mcp_response(chat_model, message_list, mcp_servers):
-    client = MultiServerMCPClient(json.loads(mcp_servers))
-    tools = await client.get_tools()
-    agent = create_react_agent(chat_model, tools)
-    response = agent.astream({"messages": message_list}, stream_mode='messages')
-    async for chunk in response:
-        if isinstance(chunk[0], ToolMessage):
-            content = generate_tool_message_template(chunk[0].name, chunk[0].content)
-            chunk[0].content = content
-            yield chunk[0]
-        if isinstance(chunk[0], AIMessageChunk):
-            yield chunk[0]
-
-
-def mcp_response_generator(chat_model, message_list, mcp_servers):
-    loop = asyncio.new_event_loop()
-    try:
-        async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
-        while True:
-            try:
-                chunk = loop.run_until_complete(anext_async(async_gen))
-                yield chunk
-            except StopAsyncIteration:
-                break
-    except Exception as e:
-        maxkb_logger.error(f'Exception: {e}', traceback.format_exc())
-    finally:
-        loop.close()
-
-
-async def anext_async(agen):
-    return await agen.__anext__()
-

 def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     """
