Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion apps/common/utils/tool_code.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8

import ast
import os
import pickle
import subprocess
Expand Down Expand Up @@ -83,6 +83,39 @@ def exec_code(self, code_str, keywords):
return result.get('data')
raise Exception(result.get('msg'))

def generate_mcp_server_code(self, _code):
self.validate_banned_keywords(_code)

# 解析代码,提取导入语句和函数定义
try:
tree = ast.parse(_code)
except SyntaxError:
return _code

imports = []
functions = []
other_code = []

for node in tree.body:
if isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom):
imports.append(ast.unparse(node))
elif isinstance(node, ast.FunctionDef):
# 为函数添加 @mcp.tool() 装饰器
func_code = ast.unparse(node)
functions.append(f"@mcp.tool()\n{func_code}\n")
else:
other_code.append(ast.unparse(node))

# 构建完整的 MCP 服务器代码
code_parts = ["from mcp.server.fastmcp import FastMCP"]
code_parts.extend(imports)
code_parts.append(f"\nmcp = FastMCP(\"{uuid.uuid7()}\")\n")
code_parts.extend(other_code)
code_parts.extend(functions)
code_parts.append("\nmcp.run(transport=\"stdio\")\n")

return "\n".join(code_parts)

def _exec_sandbox(self, _code, _id):
exec_python_file = f'{self.sandbox_path}/execute/{_id}.py'
with open(exec_python_file, 'w') as file:
Expand Down
18 changes: 18 additions & 0 deletions apps/knowledge/migrations/0002_alter_file_source_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 5.2.4 on 2025-08-11 09:45

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('knowledge', '0001_initial'),
]

operations = [
migrations.AlterField(
model_name='file',
name='source_type',
field=models.CharField(choices=[('KNOWLEDGE', 'Knowledge'), ('APPLICATION', 'Application'), ('TOOL', 'Tool'), ('DOCUMENT', 'Document'), ('CHAT', 'Chat'), ('SYSTEM', 'System'), ('TEMPORARY_30_MINUTE', 'Temporary 30 Minute'), ('TEMPORARY_120_MINUTE', 'Temporary 120 Minute'), ('TEMPORARY_1_DAY', 'Temporary 1 Day')], db_index=True, default='TEMPORARY_120_MINUTE', verbose_name='资源类型'),
),
]
18 changes: 18 additions & 0 deletions apps/tools/migrations/0002_alter_tool_tool_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 5.2.4 on 2025-08-11 09:37

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('tools', '0001_initial'),
]

operations = [
migrations.AlterField(
model_name='tool',
name='tool_type',
field=models.CharField(choices=[('INTERNAL', '内置'), ('CUSTOM', '自定义'), ('MCP', 'MCP工具')], db_index=True, default='CUSTOM', max_length=20, verbose_name='工具类型'),
),
]
1 change: 1 addition & 0 deletions apps/tools/models/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ToolScope(models.TextChoices):
class ToolType(models.TextChoices):
INTERNAL = "INTERNAL", '内置'
CUSTOM = "CUSTOM", "自定义"
MCP = "MCP", "MCP工具"


class Tool(AppModelMixin):
Expand Down
154 changes: 154 additions & 0 deletions apps/tools/serializers/tool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-
import asyncio
import io
import json
import os
import pickle
import re
from typing import Dict

import uuid_utils.compat as uuid
from django.core import validators
Expand All @@ -12,6 +14,7 @@
from django.http import HttpResponse
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from langchain_mcp_adapters.client import MultiServerMCPClient
from pylint.lint import Run
from pylint.reporters import JSON2Reporter
from rest_framework import serializers, status
Expand All @@ -22,6 +25,7 @@
from common.field.common import UploadedImageField
from common.result import result
from common.utils.common import get_file_content
from common.utils.logger import maxkb_logger
from common.utils.rsa_util import rsa_long_decrypt, rsa_long_encrypt
from common.utils.tool_code import ToolExecutor
from knowledge.models import File, FileSourceType
Expand Down Expand Up @@ -103,6 +107,18 @@ def encryption(message: str):
return pre_str + content + end_str


def validate_mcp_config(servers: Dict):
async def validate():
client = MultiServerMCPClient(servers)
await client.get_tools()

try:
asyncio.run(validate())
except Exception as e:
maxkb_logger.error(f"validate mcp config error: {e}, servers: {servers}")
raise serializers.ValidationError(_('MCP configuration is invalid'))


class ToolModelSerializer(serializers.ModelSerializer):
class Meta:
model = Tool
Expand Down Expand Up @@ -201,6 +217,131 @@ class PylintInstance(serializers.Serializer):


class ToolSerializer(serializers.Serializer):
class Query(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
folder_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_('folder id'))
name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('tool name'))
user_id = serializers.UUIDField(required=False, allow_null=True, label=_('user id'))
scope = serializers.CharField(required=True, label=_('scope'))
tool_type = serializers.CharField(required=False, label=_('tool type'), allow_null=True, allow_blank=True)
create_user = serializers.UUIDField(required=False, label=_('create user'), allow_null=True)

def get_query_set(self, workspace_manage, is_x_pack_ee):
tool_query_set = QuerySet(Tool).filter(workspace_id=self.data.get('workspace_id'))
folder_query_set = QuerySet(ToolFolder)
default_query_set = QuerySet(Tool)

workspace_id = self.data.get('workspace_id')
user_id = self.data.get('user_id')
scope = self.data.get('scope')
tool_type = self.data.get('tool_type')
desc = self.data.get('desc')
name = self.data.get('name')
folder_id = self.data.get('folder_id')
create_user = self.data.get('create_user')

if workspace_id is not None:
folder_query_set = folder_query_set.filter(workspace_id=workspace_id)
default_query_set = default_query_set.filter(workspace_id=workspace_id)
if folder_id is not None:
folder_query_set = folder_query_set.filter(parent=folder_id)
default_query_set = default_query_set.filter(folder_id=folder_id)
if name is not None:
folder_query_set = folder_query_set.filter(name__icontains=name)
default_query_set = default_query_set.filter(name__icontains=name)
if desc is not None:
folder_query_set = folder_query_set.filter(desc__icontains=desc)
default_query_set = default_query_set.filter(desc__icontains=desc)
if create_user is not None:
tool_query_set = tool_query_set.filter(user_id=create_user)
folder_query_set = folder_query_set.filter(user_id=create_user)

default_query_set = default_query_set.order_by("-create_time")

if scope is not None:
tool_query_set = tool_query_set.filter(scope=scope)
if tool_type:
tool_query_set = tool_query_set.filter(tool_type=tool_type)

query_set_dict = {
'folder_query_set': folder_query_set,
'tool_query_set': tool_query_set,
'default_query_set': default_query_set,
}
if not workspace_manage:
query_set_dict['workspace_user_resource_permission_query_set'] = QuerySet(
WorkspaceUserResourcePermission).filter(
auth_target_type="TOOL",
workspace_id=workspace_id,
user_id=user_id
)
return query_set_dict

def get_authorized_query_set(self):
default_query_set = QuerySet(Tool)
tool_type = self.data.get('tool_type')
desc = self.data.get('desc')
name = self.data.get('name')
create_user = self.data.get('create_user')

default_query_set = default_query_set.filter(workspace_id='None')
default_query_set = default_query_set.filter(scope=ToolScope.SHARED)
if name is not None:
default_query_set = default_query_set.filter(name__icontains=name)
if desc is not None:
default_query_set = default_query_set.filter(desc__icontains=desc)
if create_user is not None:
default_query_set = default_query_set.filter(user_id=create_user)
if tool_type:
default_query_set = default_query_set.filter(tool_type=tool_type)

default_query_set = default_query_set.order_by("-create_time")

return default_query_set

@staticmethod
def is_x_pack_ee():
workspace_user_role_mapping_model = DatabaseModelManage.get_model("workspace_user_role_mapping")
role_permission_mapping_model = DatabaseModelManage.get_model("role_permission_mapping_model")
return workspace_user_role_mapping_model is not None and role_permission_mapping_model is not None

def get_tools(self):
self.is_valid(raise_exception=True)

workspace_manage = is_workspace_manage(self.data.get('user_id'), self.data.get('workspace_id'))
is_x_pack_ee = self.is_x_pack_ee()
results = native_search(
self.get_query_set(workspace_manage, is_x_pack_ee),
get_file_content(
os.path.join(
PROJECT_DIR,
"apps", "tools", 'sql',
'list_tool.sql' if workspace_manage else (
'list_tool_user_ee.sql' if is_x_pack_ee else 'list_tool_user.sql'
)
)
),
)

get_authorized_tool = DatabaseModelManage.get_model("get_authorized_tool")
shared_queryset = QuerySet(Tool).none()
if get_authorized_tool is not None:
shared_queryset = self.get_authorized_query_set()
shared_queryset = get_authorized_tool(shared_queryset, self.data.get('workspace_id'))

return {
'shared_tools': [
ToolModelSerializer(data).data for data in shared_queryset
],
'tools': [
{
**tool,
'input_field_list': json.loads(tool.get('input_field_list', '[]')),
'init_field_list': json.loads(tool.get('init_field_list', '[]')),
} for tool in results if tool['resource_type'] == 'tool'
],
}

class Create(serializers.Serializer):
user_id = serializers.UUIDField(required=True, label=_('user id'))
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
Expand All @@ -212,6 +353,10 @@ def insert(self, instance, with_valid=True):
ToolCreateRequest(data=instance).is_valid(raise_exception=True)
# 校验代码是否包括禁止的关键字
ToolExecutor().validate_banned_keywords(instance.get('code', ''))
# 校验mcp json
if instance.get('tool_type') == ToolType.MCP.value:
validate_mcp_config(json.loads(instance.get('code')))

tool_id = uuid.uuid7()
Tool(
id=tool_id,
Expand All @@ -223,6 +368,7 @@ def insert(self, instance, with_valid=True):
input_field_list=instance.get('input_field_list', []),
init_field_list=instance.get('init_field_list', []),
scope=instance.get('scope', ToolScope.WORKSPACE),
tool_type=instance.get('tool_type', ToolType.CUSTOM),
folder_id=instance.get('folder_id', self.data.get('workspace_id')),
is_active=False
).save()
Expand Down Expand Up @@ -326,6 +472,10 @@ def edit(self, instance, with_valid=True):
ToolEditRequest(data=instance).is_valid(raise_exception=True)
# 校验代码是否包括禁止的关键字
ToolExecutor().validate_banned_keywords(instance.get('code', ''))
# 校验mcp json
if instance.get('tool_type') == ToolType.MCP.value:
validate_mcp_config(json.loads(instance.get('code')))

if not QuerySet(Tool).filter(id=self.data.get('id')).exists():
raise serializers.ValidationError(_('Tool not found'))

Expand Down Expand Up @@ -574,6 +724,7 @@ class Query(serializers.Serializer):
name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('tool name'))
user_id = serializers.UUIDField(required=False, allow_null=True, label=_('user id'))
scope = serializers.CharField(required=True, label=_('scope'))
tool_type = serializers.CharField(required=False, label=_('tool type'), allow_null=True, allow_blank=True)
create_user = serializers.UUIDField(required=False, label=_('create user'), allow_null=True)

def page_tool(self, current_page: int, page_size: int):
Expand Down Expand Up @@ -609,6 +760,7 @@ def get_query_set(self, workspace_manage, is_x_pack_ee):
workspace_id = self.data.get('workspace_id')
user_id = self.data.get('user_id')
scope = self.data.get('scope')
tool_type = self.data.get('tool_type')
desc = self.data.get('desc')
name = self.data.get('name')
folder_id = self.data.get('folder_id')
Expand All @@ -634,6 +786,8 @@ def get_query_set(self, workspace_manage, is_x_pack_ee):

if scope is not None:
tool_query_set = tool_query_set.filter(scope=scope)
if tool_type:
tool_query_set = tool_query_set.filter(tool_type=tool_type)

query_set_dict = {
'folder_query_set': folder_query_set,
Expand Down
1 change: 1 addition & 0 deletions apps/tools/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
path('workspace/<str:workspace_id>/tool/import', views.ToolView.Import.as_view()),
path('workspace/<str:workspace_id>/tool/pylint', views.ToolView.Pylint.as_view()),
path('workspace/<str:workspace_id>/tool/debug', views.ToolView.Debug.as_view()),
path('workspace/<str:workspace_id>/tool/tool_list', views.ToolView.Query.as_view()),
path('workspace/<str:workspace_id>/tool/<str:tool_id>', views.ToolView.Operate.as_view()),
path('workspace/<str:workspace_id>/tool/<str:tool_id>/edit_icon', views.ToolView.EditIcon.as_view()),
path('workspace/<str:workspace_id>/tool/<str:tool_id>/export', views.ToolView.Export.as_view()),
Expand Down
33 changes: 33 additions & 0 deletions apps/tools/views/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def get(self, request: Request, workspace_id: str):
'folder_id': request.query_params.get('folder_id'),
'name': request.query_params.get('name'),
'scope': request.query_params.get('scope', ToolScope.WORKSPACE),
'tool_type': request.query_params.get('tool_type'),
'user_id': request.user.id,
'create_user': request.query_params.get('create_user'),
}
Expand Down Expand Up @@ -209,11 +210,43 @@ def get(self, request: Request, workspace_id: str, current_page: int, page_size:
'folder_id': request.query_params.get('folder_id'),
'name': request.query_params.get('name'),
'scope': request.query_params.get('scope'),
'tool_type': request.query_params.get('tool_type'),
'user_id': request.user.id,
'create_user': request.query_params.get('create_user'),
}
).page_tool_with_folders(current_page, page_size))

class Query(APIView):
authentication_classes = [TokenAuth]

@extend_schema(
methods=['GET'],
description=_('Get tool list '),
summary=_('Get tool list'),
operation_id=_('Get tool list'), # type: ignore
parameters=ToolReadAPI.get_parameters(),
responses=ToolReadAPI.get_response(),
tags=[_('Tool')] # type: ignore
)
@has_permissions(
PermissionConstants.TOOL_READ.get_workspace_permission(),
PermissionConstants.TOOL_READ.get_workspace_permission_workspace_manage_role(),
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()
)
@log(menu='Tool', operate='Get tool list')
def get(self, request: Request, workspace_id: str):
return result.success(ToolSerializer.Query(
data={
'workspace_id': workspace_id,
'folder_id': request.query_params.get('folder_id'),
'name': request.query_params.get('name'),
'scope': request.query_params.get('scope'),
'tool_type': request.query_params.get('tool_type'),
'user_id': request.user.id,
'create_user': request.query_params.get('create_user'),
}
).get_tools())

class Import(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]
Expand Down
Loading
Loading