diff --git a/.github/workflows/build-and-push-vector-model.yml b/.github/workflows/build-and-push-vector-model.yml index 556a398b885..0c51e86f62b 100644 --- a/.github/workflows/build-and-push-vector-model.yml +++ b/.github/workflows/build-and-push-vector-model.yml @@ -5,7 +5,7 @@ on: inputs: dockerImageTag: description: 'Docker Image Tag' - default: 'v2.0.2' + default: 'v2.0.3' required: true architecture: description: 'Architecture' diff --git a/.gitignore b/.gitignore index cc289d0865d..17f102d33ba 100644 --- a/.gitignore +++ b/.gitignore @@ -188,4 +188,5 @@ apps/models_provider/impl/*/icon/ apps/models_provider/impl/tencent_model_provider/credential/stt.py apps/models_provider/impl/tencent_model_provider/model/stt.py tmp/ -config.yml \ No newline at end of file +config.yml +.SANDBOX_BANNED_HOSTS \ No newline at end of file diff --git a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py index 233551e7ac9..92083a6da2e 100644 --- a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py +++ b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py @@ -224,11 +224,17 @@ def handler(self, workflow): class BaseLoopNode(ILoopNode): def save_context(self, details, workflow_manage): + self.context['loop_context_data'] = details.get('loop_context_data') + self.context['loop_answer_data'] = details.get('loop_answer_data') + self.context['loop_node_data'] = details.get('loop_node_data') self.context['result'] = details.get('result') - for key, value in details['context'].items(): - if key not in self.context: - self.context[key] = value - self.answer_text = str(details.get('result')) + self.context['params'] = details.get('params') + self.context['run_time'] = details.get('run_time') + self.context['index'] = details.get('current_index') + self.context['item'] = details.get('current_item') + for key, value in (details.get('loop_context_data') or {}).items(): + self.context[key] = value + self.answer_text = "" def get_answer_list(self) -> List[Answer] | None: result = [] diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index 527b11ce841..feb8b62dc20 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -8,7 +8,8 @@ """ import asyncio import json -import traceback +import queue +import threading from typing import Iterator from django.http import StreamingHttpResponse @@ -227,6 +228,30 @@ def generate_tool_message_template(name, context): return tool_message_template % (name, tool_message_json_template % (context)) +# 全局单例事件循环 +_global_loop = None +_loop_thread = None +_loop_lock = threading.Lock() + + +def get_global_loop(): + """获取全局共享的事件循环""" + global _global_loop, _loop_thread + + with _loop_lock: + if _global_loop is None: + _global_loop = asyncio.new_event_loop() + + def run_forever(): + asyncio.set_event_loop(_global_loop) + _global_loop.run_forever() + + _loop_thread = threading.Thread(target=run_forever, daemon=True, name="GlobalAsyncLoop") + _loop_thread.start() + + return _global_loop + + async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True): client = MultiServerMCPClient(json.loads(mcp_servers)) tools = await client.get_tools() @@ -242,19 +267,31 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_ def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True): - loop = asyncio.new_event_loop() - try: - async_gen = 
_yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable) - while True: - try: - chunk = loop.run_until_complete(anext_async(async_gen)) - yield chunk - except StopAsyncIteration: - break - except Exception as e: - maxkb_logger.error(f'Exception: {e}', exc_info=True) - finally: - loop.close() + """使用全局事件循环,不创建新实例""" + result_queue = queue.Queue() + loop = get_global_loop() # 使用共享循环 + + async def _run(): + try: + async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable) + async for chunk in async_gen: + result_queue.put(('data', chunk)) + except Exception as e: + maxkb_logger.error(f'Exception: {e}', exc_info=True) + result_queue.put(('error', e)) + finally: + result_queue.put(('done', None)) + + # 在全局循环中调度任务 + asyncio.run_coroutine_threadsafe(_run(), loop) + + while True: + msg_type, data = result_queue.get() + if msg_type == 'done': + break + if msg_type == 'error': + raise data + yield data async def anext_async(agen): diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index d35173ac92c..343d6f9aa2d 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -234,25 +234,62 @@ def run_block(self, language='zh'): 非流式响应 @return: 结果 """ - self.run_chain_async(None, None, language) - while self.is_run(): - pass - details = self.get_runtime_details() - message_tokens = sum([row.get('message_tokens') for row in details.values() if - 'message_tokens' in row and row.get('message_tokens') is not None]) - answer_tokens = sum([row.get('answer_tokens') for row in details.values() if - 'answer_tokens' in row and row.get('answer_tokens') is not None]) - answer_text_list = self.get_answer_text_list() - answer_text = '\n\n'.join( - '\n\n'.join([a.get('content') for a in answer]) for answer in - answer_text_list) - answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) - self.work_flow_post_handler.handler(self) - return self.base_to_response.to_block_response(self.params['chat_id'], - self.params['chat_record_id'], answer_text, True - , message_tokens, answer_tokens, - _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR, - other_params={'answer_list': answer_list}) + try: + self.run_chain_async(None, None, language) + while self.is_run(): + pass + details = self.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) + answer_text_list = self.get_answer_text_list() + answer_text = '\n\n'.join( + '\n\n'.join([a.get('content') for a in answer]) for answer in + answer_text_list) + answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) + self.work_flow_post_handler.handler(self) + + res = self.base_to_response.to_block_response(self.params['chat_id'], + self.params['chat_record_id'], answer_text, True + , message_tokens, answer_tokens, + _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR, + other_params={'answer_list': answer_list}) + finally: + self._cleanup() + return res + + def _cleanup(self): + """清理所有对象引用""" + # 清理列表 + self.future_list.clear() + self.field_list.clear() + self.global_field_list.clear() + self.chat_field_list.clear() + self.image_list.clear() + self.video_list.clear() + 
self.document_list.clear()
+        self.audio_list.clear()
+        self.other_list.clear()
+        if hasattr(self, 'node_context'):
+            self.node_context.clear()
+
+        # 清理字典
+        self.context.clear()
+        self.chat_context.clear()
+        self.form_data.clear()
+
+        # 清理对象引用
+        self.node_chunk_manage = None
+        self.work_flow_post_handler = None
+        self.flow = None
+        self.start_node = None
+        self.current_node = None
+        self.current_result = None
+        self.chat_record = None
+        self.base_to_response = None
+        self.params = None
+        self.lock = None
 
     def run_stream(self, current_node, node_result_future, language='zh'):
         """
@@ -307,6 +344,7 @@ def await_result(self):
                                                                   '', [], '', True, message_tokens, answer_tokens,
                                                                   {})
+            self._cleanup()
 
     def run_chain_async(self, current_node, node_result_future, language='zh'):
         future = executor.submit(self.run_chain_manage, current_node, node_result_future, language)
diff --git a/apps/application/serializers/common.py b/apps/application/serializers/common.py
index 10d7efb2e5e..e8ac1080914 100644
--- a/apps/application/serializers/common.py
+++ b/apps/application/serializers/common.py
@@ -6,7 +6,6 @@
 @date:2025/6/9 13:42
 @desc:
 """
-from datetime import datetime
 from typing import List
 
 from django.core.cache import cache
@@ -226,10 +225,66 @@ def append_chat_record(self, chat_record: ChatRecord):
             chat_record.save()
         ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()
 
+    def to_dict(self):
+
+        return {
+            'chat_id': self.chat_id,
+            'chat_user_id': self.chat_user_id,
+            'chat_user_type': self.chat_user_type,
+            'knowledge_id_list': self.knowledge_id_list,
+            'exclude_document_id_list': self.exclude_document_id_list,
+            'application_id': self.application_id,
+            'chat_record_list': [self.chat_record_to_map(c) for c in self.chat_record_list],
+            'debug': self.debug
+        }
+
+    def chat_record_to_map(self, chat_record):
+        return {'id': chat_record.id,
+                'chat_id': chat_record.chat_id,
+                'vote_status': chat_record.vote_status,
+                'problem_text': chat_record.problem_text,
+                'answer_text': chat_record.answer_text,
+                'answer_text_list': chat_record.answer_text_list,
+                'message_tokens': chat_record.message_tokens,
+                'answer_tokens': chat_record.answer_tokens,
+                'const': chat_record.const,
+                'details': chat_record.details,
+                'improve_paragraph_id_list': chat_record.improve_paragraph_id_list,
+                'run_time': chat_record.run_time,
+                'index': chat_record.index}
+
+    @staticmethod
+    def map_to_chat_record(chat_record_dict):
+        return ChatRecord(id=chat_record_dict.get('id'),
+                   chat_id=chat_record_dict.get('chat_id'),
+                   vote_status=chat_record_dict.get('vote_status'),
+                   problem_text=chat_record_dict.get('problem_text'),
+                   answer_text=chat_record_dict.get('answer_text'),
+                   answer_text_list=chat_record_dict.get('answer_text_list'),
+                   message_tokens=chat_record_dict.get('message_tokens'),
+                   answer_tokens=chat_record_dict.get('answer_tokens'),
+                   const=chat_record_dict.get('const'),
+                   details=chat_record_dict.get('details'),
+                   improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'),
+                   run_time=chat_record_dict.get('run_time'),
+                   index=chat_record_dict.get('index'), )
+
     def set_cache(self):
-        cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
+        cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(),
+                  version=Cache_Version.CHAT_INFO.get_version(),
                   timeout=60 * 30)
 
+    @staticmethod
+    def map_to_chat_info(chat_info_dict):
+        return ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'),
+                        chat_info_dict.get('chat_user_type'),
chat_info_dict.get('knowledge_id_list'), + chat_info_dict.get('exclude_document_id_list'), + chat_info_dict.get('application_id'), + [ChatInfo.map_to_chat_record(c_r) for c_r in chat_info_dict.get('chat_record_list')]) + @staticmethod def get_cache(chat_id): - return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version()) + chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT_INFO.get_version()) + if chat_info_dict: + return ChatInfo.map_to_chat_info(chat_info_dict) + return None diff --git a/apps/application/views/application_chat.py b/apps/application/views/application_chat.py index 381bde9ddda..d60206ded5a 100644 --- a/apps/application/views/application_chat.py +++ b/apps/application/views/application_chat.py @@ -125,8 +125,8 @@ class OpenView(APIView): responses=None, tags=[_('Application')] # type: ignore ) - @has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(), - PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(), + @has_permissions(PermissionConstants.APPLICATION_READ.get_workspace_application_permission(), + PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(), ViewPermission([RoleConstants.USER.get_workspace_role()], [PermissionConstants.APPLICATION.get_workspace_application_permission()], CompareConstants.AND), @@ -167,8 +167,8 @@ class PromptGenerateView(APIView): responses=None, tags=[_('Application')] # type: ignore ) - @has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(), - PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(), + @has_permissions(PermissionConstants.APPLICATION_READ.get_workspace_application_permission(), + PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(), ViewPermission([RoleConstants.USER.get_workspace_role()], [PermissionConstants.APPLICATION.get_workspace_application_permission()], CompareConstants.AND), diff --git a/apps/application/views/application_chat_record.py b/apps/application/views/application_chat_record.py index dbcb246e45b..0d59146b29d 100644 --- a/apps/application/views/application_chat_record.py +++ b/apps/application/views/application_chat_record.py @@ -93,8 +93,8 @@ class ApplicationChatRecordOperateAPI(APIView): ) @has_permissions(PermissionConstants.APPLICATION_CHAT_LOG_READ.get_workspace_application_permission(), PermissionConstants.APPLICATION_CHAT_LOG_READ.get_workspace_permission_workspace_manage_role(), - PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(), - PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(), + PermissionConstants.APPLICATION_READ.get_workspace_application_permission(), + PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(), ViewPermission([RoleConstants.USER.get_workspace_role()], [PermissionConstants.APPLICATION.get_workspace_application_permission()], CompareConstants.AND), diff --git a/apps/common/auth/common.py b/apps/common/auth/common.py index 40158f7d239..ad8e0e50a48 100644 --- a/apps/common/auth/common.py +++ b/apps/common/auth/common.py @@ -45,7 +45,7 @@ def to_string(self): value = json.dumps(self.to_dict()) authentication = encrypt(value) cache_key = hashlib.sha256(authentication.encode()).hexdigest() - authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.value, timeout=60 * 60 * 2) + 
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.get_version(), timeout=60 * 60 * 2)
         return authentication
 
     @staticmethod
diff --git a/apps/common/config/tokenizer_manage_config.py b/apps/common/config/tokenizer_manage_config.py
index 47a6d61e902..9a3ae73f2c2 100644
--- a/apps/common/config/tokenizer_manage_config.py
+++ b/apps/common/config/tokenizer_manage_config.py
@@ -7,18 +7,26 @@
 @desc:
 """
+import os
+
+
+class MKTokenizer:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def encode(self, text):
+        return self.tokenizer.encode(text).ids
+
 
 class TokenizerManage:
     tokenizer = None
 
     @staticmethod
     def get_tokenizer():
-        from transformers import BertTokenizer
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = BertTokenizer.from_pretrained(
-                'bert-base-cased',
-                cache_dir="/opt/maxkb-app/model/tokenizer",
-                local_files_only=True,
-                resume_download=False,
-                force_download=False)
-        return TokenizerManage.tokenizer
+        from tokenizers import Tokenizer
+        if TokenizerManage.tokenizer is None:
+            # Load the tokenizer only on first use and keep it cached on the class
+            model_path = os.path.join("/opt/maxkb-app", "model", "tokenizer", "models--bert-base-cased")
+            with open(f"{model_path}/refs/main", encoding="utf-8") as f:
+                snapshot = f.read().strip()
+            TokenizerManage.tokenizer = Tokenizer.from_file(f"{model_path}/snapshots/{snapshot}/tokenizer.json")
+        return MKTokenizer(TokenizerManage.tokenizer)
diff --git a/apps/common/constants/cache_version.py b/apps/common/constants/cache_version.py
index 2cf889c17d5..6664acb5623 100644
--- a/apps/common/constants/cache_version.py
+++ b/apps/common/constants/cache_version.py
@@ -30,6 +30,8 @@ class Cache_Version(Enum):
     # 对话
     CHAT = "CHAT", lambda key: key
+    CHAT_INFO = "CHAT_INFO", lambda key: key
+    CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key
 
     # 应用API KEY
diff --git a/apps/common/event/__init__.py b/apps/common/event/__init__.py
index 59a61e9abd5..e8f0b3d819b 100644
--- a/apps/common/event/__init__.py
+++ b/apps/common/event/__init__.py
@@ -7,9 +7,10 @@
 @desc:
 """
 from django.core.cache import cache
+from django.db.models import QuerySet
 from django.utils.translation import gettext as _
 
-from .listener_manage import *
+
 from ..constants.cache_version import Cache_Version
 from ..db.sql_execute import update_execute
 from ..utils.lock import RedisLock
diff --git a/apps/common/handle/impl/text/zip_split_handle.py b/apps/common/handle/impl/text/zip_split_handle.py
index 5752fe0d753..6609a981c33 100644
--- a/apps/common/handle/impl/text/zip_split_handle.py
+++ b/apps/common/handle/impl/text/zip_split_handle.py
@@ -15,7 +15,6 @@
 import uuid_utils.compat as uuid
 
 from charset_normalizer import detect
-from django.db.models import QuerySet
 from django.utils.translation import gettext_lazy as _
 
 from common.handle.base_split_handle import BaseSplitHandle
@@ -39,7 +38,6 @@ def get_buffer(self, file):
         return self.buffer
 
 
-
 default_split_handle = TextSplitHandle()
 split_handles = [
     HTMLSplitHandle(),
diff --git a/apps/common/management/commands/services/command.py b/apps/common/management/commands/services/command.py
index a3fa3d72ffb..5dfb7570df0 100644
--- a/apps/common/management/commands/services/command.py
+++ b/apps/common/management/commands/services/command.py
@@ -1,18 +1,16 @@
 import math
+import os
 
 from django.core.management.base import BaseCommand
 from django.db.models import TextChoices
 
-from .hands import *
 from .utils import ServicesUtil
-import os
 
 
 class Services(TextChoices):
     gunicorn = 'gunicorn', 'gunicorn'
     celery_default = 'celery_default', 'celery_default'
     local_model = 'local_model', 'local_model'
-
scheduler = 'scheduler', 'scheduler' web = 'web', 'web' celery = 'celery', 'celery' celery_model = 'celery_model', 'celery_model' @@ -26,7 +24,6 @@ def get_service_object_class(cls, name): cls.gunicorn.value: services.GunicornService, cls.celery_default: services.CeleryDefaultService, cls.local_model: services.GunicornLocalModelService, - cls.scheduler: services.SchedulerService, } return services_map.get(name) @@ -42,13 +39,10 @@ def celery_services(cls): def task_services(cls): return cls.celery_services() - @classmethod - def scheduler_services(cls): - return [cls.scheduler] @classmethod def all_services(cls): - return cls.web_services() + cls.task_services() + cls.scheduler_services() + return cls.web_services() + cls.task_services() @classmethod def export_services_values(cls): @@ -102,7 +96,7 @@ def add_arguments(self, parser): ) parser.add_argument('-d', '--daemon', nargs="?", const=True) parser.add_argument('-w', '--worker', type=int, nargs="?", - default=2 if os.cpu_count() > 6 else math.floor(os.cpu_count() / 2)) + default=3 if os.cpu_count() > 6 else max(1, math.floor(os.cpu_count() / 2))) parser.add_argument('-f', '--force', nargs="?", const=True) def initial_util(self, *args, **options): diff --git a/apps/common/management/commands/services/services/celery_default.py b/apps/common/management/commands/services/services/celery_default.py index 5d3e6d7b8a4..f8f4b54175e 100644 --- a/apps/common/management/commands/services/services/celery_default.py +++ b/apps/common/management/commands/services/services/celery_default.py @@ -1,4 +1,8 @@ +import os +import subprocess + from .celery_base import CeleryBaseService +from django.conf import settings __all__ = ['CeleryDefaultService'] @@ -8,3 +12,20 @@ class CeleryDefaultService(CeleryBaseService): def __init__(self, **kwargs): kwargs['queue'] = 'celery' super().__init__(**kwargs) + + def open_subprocess(self): + env = os.environ.copy() + env['LC_ALL'] = 'C.UTF-8' + env['PYTHONOPTIMIZE'] = '1' + env['ANSIBLE_FORCE_COLOR'] = 'True' + env['PYTHONPATH'] = settings.APPS_DIR + env['SERVER_NAME'] = 'celery' + if os.getuid() == 0: + env.setdefault('C_FORCE_ROOT', '1') + kwargs = { + 'cwd': self.cwd, + 'stderr': self.log_file, + 'stdout': self.log_file, + 'env': env + } + self._process = subprocess.Popen(self.cmd, **kwargs) diff --git a/apps/common/management/commands/services/services/gunicorn.py b/apps/common/management/commands/services/services/gunicorn.py index a0c89a920bd..4a72339206d 100644 --- a/apps/common/management/commands/services/services/gunicorn.py +++ b/apps/common/management/commands/services/services/gunicorn.py @@ -1,3 +1,5 @@ +import subprocess + from .base import BaseService from ..hands import * @@ -16,15 +18,17 @@ def cmd(self): log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' bind = f'{HTTP_HOST}:{HTTP_PORT}' + max_requests = 10240 if int(self.worker) > 1 else 0 cmd = [ 'gunicorn', 'maxkb.wsgi:application', '-b', bind, - '--preload', '-k', 'gthread', '--threads', '200', '-w', str(self.worker), - '--max-requests', '10240', + '--max-requests', str(max_requests), '--max-requests-jitter', '2048', + '--timeout', '0', + '--graceful-timeout', '0', '--access-logformat', log_format, '--access-logfile', '/dev/null', '--error-logfile', '-' @@ -36,3 +40,15 @@ def cmd(self): @property def cwd(self): return APPS_DIR + + def open_subprocess(self): + # 复制当前环境变量,并设置 ENABLE_SCHEDULER=1 + env = os.environ.copy() + env['SERVER_NAME'] = 'web' + kwargs = { + 'cwd': self.cwd, + 'stderr': self.log_file, + 'stdout': self.log_file, + 
'env': env + } + self._process = subprocess.Popen(self.cmd, **kwargs) diff --git a/apps/common/management/commands/services/services/local_model.py b/apps/common/management/commands/services/services/local_model.py index 8383a58445c..a37f0d16b04 100644 --- a/apps/common/management/commands/services/services/local_model.py +++ b/apps/common/management/commands/services/services/local_model.py @@ -6,6 +6,8 @@ @date:2024/8/21 13:28 @desc: """ +import subprocess + from maxkb.const import CONFIG from .base import BaseService from ..hands import * @@ -22,19 +24,20 @@ def __init__(self, **kwargs): @property def cmd(self): print("\n- Start Gunicorn Local Model WSGI HTTP Server") - os.environ.setdefault('SERVER_NAME', 'local_model') log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' worker = CONFIG.get("LOCAL_MODEL_HOST_WORKER", 1) + max_requests = 10240 if int(worker) > 1 else 0 cmd = [ 'gunicorn', 'maxkb.wsgi:application', '-b', bind, - '--preload', '-k', 'gthread', '--threads', '200', '-w', str(worker), - '--max-requests', '10240', + '--max-requests', str(max_requests), '--max-requests-jitter', '2048', + '--timeout', '0', + '--graceful-timeout', '0', '--access-logformat', log_format, '--access-logfile', '/dev/null', '--error-logfile', '-' @@ -46,3 +49,15 @@ def cmd(self): @property def cwd(self): return APPS_DIR + + def open_subprocess(self): + # 复制当前环境变量,并设置 ENABLE_SCHEDULER=1 + env = os.environ.copy() + env['SERVER_NAME'] = 'local_model' + kwargs = { + 'cwd': self.cwd, + 'stderr': self.log_file, + 'stdout': self.log_file, + 'env': env + } + self._process = subprocess.Popen(self.cmd, **kwargs) diff --git a/apps/common/management/commands/services/services/scheduler.py b/apps/common/management/commands/services/services/scheduler.py index d82e5b91920..dcef849b8e1 100644 --- a/apps/common/management/commands/services/services/scheduler.py +++ b/apps/common/management/commands/services/services/scheduler.py @@ -18,15 +18,17 @@ def cmd(self): log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s ' bind = f'127.0.0.1:6060' + max_requests = 10240 if int(self.worker) > 1 else 0 cmd = [ 'gunicorn', 'maxkb.wsgi:application', '-b', bind, - '--preload', '-k', 'gthread', '--threads', '200', '-w', str(self.worker), - '--max-requests', '10240', + '--max-requests', str(max_requests), '--max-requests-jitter', '2048', + '--timeout', '0', + '--graceful-timeout', '0', '--access-logformat', log_format, '--access-logfile', '/dev/null', '--error-logfile', '-' diff --git a/apps/common/utils/tool_code.py b/apps/common/utils/tool_code.py index 8292ed39db6..2e6abd4c65c 100644 --- a/apps/common/utils/tool_code.py +++ b/apps/common/utils/tool_code.py @@ -5,21 +5,20 @@ import subprocess import sys from textwrap import dedent - +import socket import uuid_utils.compat as uuid from django.utils.translation import gettext_lazy as _ - from maxkb.const import BASE_DIR, CONFIG from maxkb.const import PROJECT_DIR +from common.utils.logger import maxkb_logger python_directory = sys.executable - class ToolExecutor: def __init__(self, sandbox=False): self.sandbox = sandbox if sandbox: - self.sandbox_path = '/opt/maxkb-app/sandbox' + self.sandbox_path = CONFIG.get("SANDBOX_HOME", '/opt/maxkb-app/sandbox') self.user = 'sandbox' else: self.sandbox_path = os.path.join(PROJECT_DIR, 'data', 'sandbox') @@ -28,6 +27,21 @@ def __init__(self, sandbox=False): if self.sandbox: os.system(f"chown -R {self.user}:root {self.sandbox_path}") self.banned_keywords = 
CONFIG.get("SANDBOX_PYTHON_BANNED_KEYWORDS", 'nothing_is_banned').split(','); + try: + banned_hosts_file_path = f'{self.sandbox_path}/.SANDBOX_BANNED_HOSTS' + if os.path.exists(banned_hosts_file_path): + os.remove(banned_hosts_file_path) + banned_hosts = CONFIG.get("SANDBOX_PYTHON_BANNED_HOSTS", '').strip() + if banned_hosts: + hostname = socket.gethostname() + local_ip = socket.gethostbyname(hostname) + banned_hosts = f"{banned_hosts},{hostname},{local_ip}" + with open(banned_hosts_file_path, "w") as f: + f.write(banned_hosts) + os.chmod(banned_hosts_file_path, 0o644) + except Exception as e: + maxkb_logger.error(f'Failed to init SANDBOX_BANNED_HOSTS due to exception: {e}', exc_info=True) + pass def _createdir(self): old_mask = os.umask(0o077) @@ -53,13 +67,9 @@ def exec_code(self, code_str, keywords): path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps'] sys.path = [p for p in sys.path if p not in path_to_exclude] sys.path += {python_paths} - env = dict(os.environ) - for key in list(env.keys()): - if key in os.environ and (key.startswith('MAXKB') or key.startswith('POSTGRES') or key.startswith('PG') or key.startswith('REDIS') or key == 'PATH'): - del os.environ[key] locals_v={'{}'} keywords={keywords} - globals_v=globals() + globals_v={'{}'} exec({dedent(code_str)!a}, globals_v, locals_v) f_name, f = locals_v.popitem() for local in locals_v: @@ -163,10 +173,6 @@ def generate_mcp_server_code(self, code_str, params): path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps'] sys.path = [p for p in sys.path if p not in path_to_exclude] sys.path += {python_paths} -env = dict(os.environ) -for key in list(env.keys()): - if key in os.environ and (key.startswith('MAXKB') or key.startswith('POSTGRES') or key.startswith('PG') or key.startswith('REDIS') or key == 'PATH'): - del os.environ[key] exec({dedent(code)!a}) """ @@ -188,6 +194,9 @@ def get_tool_mcp_config(self, code, params): self.user, ], 'cwd': self.sandbox_path, + 'env': { + 'LD_PRELOAD': f'{self.sandbox_path}/sandbox.so', + }, 'transport': 'stdio', } else: @@ -204,6 +213,9 @@ def _exec_sandbox(self, _code, _id): file.write(_code) os.system(f"chown {self.user}:root {exec_python_file}") kwargs = {'cwd': BASE_DIR} + kwargs['env'] = { + 'LD_PRELOAD': f'{self.sandbox_path}/sandbox.so', + } subprocess_result = subprocess.run( ['su', '-s', python_directory, '-c', "exec(open('" + exec_python_file + "').read())", self.user], text=True, diff --git a/apps/folders/serializers/folder.py b/apps/folders/serializers/folder.py index 73afd65be37..790a913cafa 100644 --- a/apps/folders/serializers/folder.py +++ b/apps/folders/serializers/folder.py @@ -57,7 +57,7 @@ def get_folder_tree_serializer(source): return None -FOLDER_DEPTH = 2 # Folder 不能超过3层 +FOLDER_DEPTH = 10000 def check_depth(source, parent_id, workspace_id, current_depth=0): @@ -79,7 +79,7 @@ def check_depth(source, parent_id, workspace_id, current_depth=0): # 验证层级深度 if depth + current_depth > FOLDER_DEPTH: - raise serializers.ValidationError(_('Folder depth cannot exceed 3 levels')) + raise serializers.ValidationError(_('Folder depth cannot exceed 10000 levels')) def get_max_depth(current_node): @@ -100,6 +100,12 @@ def get_max_depth(current_node): return max_depth +def has_target_permission(workspace_id, source, user_id, target): + return QuerySet(WorkspaceUserResourcePermission).filter(workspace_id=workspace_id, user_id=user_id, + auth_target_type=source, target=target, + permission_list__contains=['MANAGE']).exists() + + class 
FolderSerializer(serializers.Serializer): id = serializers.CharField(required=True, label=_('folder id')) name = serializers.CharField(required=True, label=_('folder name')) @@ -183,13 +189,27 @@ def edit(self, instance): field in instance and instance.get(field) is not None)} QuerySet(Folder).filter(id=current_id).update(**edit_dict) + current_node.refresh_from_db() if parent_id is not None and current_id != current_node.workspace_id and current_node.parent_id != parent_id: - # Folder 不能超过3层 - current_depth = get_max_depth(current_node) - check_depth(self.data.get('source'), parent_id, current_node.workspace_id, current_depth) - parent = Folder.objects.get(id=parent_id) - current_node.move_to(parent) + + source_type = self.data.get('source') + if has_target_permission(current_node.workspace_id, source_type, self.data.get('user_id'), + parent_id) or is_workspace_manage(self.data.get('user_id'), + current_node.workspace_id): + current_depth = get_max_depth(current_node) + check_depth(self.data.get('source'), parent_id, current_node.workspace_id, current_depth) + parent = Folder.objects.get(id=parent_id) + + if QuerySet(Folder).filter(name=current_node.name, parent_id=parent_id, + workspace_id=current_node.workspace_id).exists(): + raise serializers.ValidationError(_('Folder name already exists')) + + current_node.parent = parent + current_node.save() + current_node.refresh_from_db() + else: + raise AppApiException(403, _('No permission for the target folder')) return self.one() diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py index 292dbe81bdc..8f568c11081 100644 --- a/apps/knowledge/serializers/document.py +++ b/apps/knowledge/serializers/document.py @@ -25,7 +25,7 @@ from xlwt import Utils from common.db.search import native_search, get_dynamics_model, native_page_search -from common.event import ListenerManagement +from common.event.listener_manage import ListenerManagement from common.event.common import work_thread_pool from common.exception.app_exception import AppApiException from common.field.common import UploadedFileField @@ -683,10 +683,11 @@ def delete(self): ] QuerySet(File).filter(id__in=source_file_ids).delete() QuerySet(File).filter(source_id=document_id, source_type=FileSourceType.DOCUMENT).delete() + paragraph_ids = QuerySet(model=Paragraph).filter(document_id=document_id).values_list("id", flat=True) + # 删除问题 + delete_problems_and_mappings(paragraph_ids) # 删除段落 QuerySet(model=Paragraph).filter(document_id=document_id).delete() - # 删除问题 - delete_problems_and_mappings([document_id]) # 删除向量库 delete_embedding_by_document(document_id) QuerySet(model=DocumentTag).filter(document_id=document_id).delete() @@ -1217,9 +1218,12 @@ def batch_delete(self, instance: Dict, with_valid=True): Document.objects.filter(id__in=document_id_list).values("meta")] QuerySet(File).filter(id__in=source_file_ids).delete() QuerySet(Document).filter(id__in=document_id_list).delete() - QuerySet(Paragraph).filter(document_id__in=document_id_list).delete() QuerySet(DocumentTag).filter(document_id__in=document_id_list).delete() - delete_problems_and_mappings(document_id_list) + paragraph_ids = QuerySet(Paragraph).filter(document_id__in=document_id_list).values_list("id", flat=True) + # 删除问题关系 + delete_problems_and_mappings(paragraph_ids) + # 删除段落 + QuerySet(Paragraph).filter(document_id__in=document_id_list).delete() # 删除向量库 delete_embedding_by_document_list(document_id_list) return True diff --git a/apps/knowledge/serializers/knowledge.py 
b/apps/knowledge/serializers/knowledge.py index 0f98a42a0f5..9daeb7dad8f 100644 --- a/apps/knowledge/serializers/knowledge.py +++ b/apps/knowledge/serializers/knowledge.py @@ -24,7 +24,7 @@ from common.database_model_manage.database_model_manage import DatabaseModelManage from common.db.search import native_search, get_dynamics_model, native_page_search from common.db.sql_execute import select_list -from common.event import ListenerManagement +from common.event.listener_manage import ListenerManagement from common.exception.app_exception import AppApiException from common.utils.common import post, get_file_content, parse_image from common.utils.fork import Fork, ChildLink diff --git a/apps/knowledge/serializers/paragraph.py b/apps/knowledge/serializers/paragraph.py index 28f1aa10bf0..a1df1bd8100 100644 --- a/apps/knowledge/serializers/paragraph.py +++ b/apps/knowledge/serializers/paragraph.py @@ -10,7 +10,7 @@ from rest_framework import serializers from common.db.search import page_search -from common.event import ListenerManagement +from common.event.listener_manage import ListenerManagement from common.exception.app_exception import AppApiException from common.utils.common import post from knowledge.models import Paragraph, Problem, Document, ProblemParagraphMapping, SourceType, TaskType, State, \ diff --git a/apps/knowledge/sql/list_problem.sql b/apps/knowledge/sql/list_problem.sql index affb51334fe..90f82dc6dd3 100644 --- a/apps/knowledge/sql/list_problem.sql +++ b/apps/knowledge/sql/list_problem.sql @@ -1,5 +1,6 @@ -SELECT - problem.*, - (SELECT "count"("id") FROM "problem_paragraph_mapping" WHERE problem_id="problem"."id") as "paragraph_count" - FROM - problem problem +SELECT problem.*, + (SELECT COUNT(ppm.id) + FROM problem_paragraph_mapping ppm + INNER JOIN paragraph p ON ppm.paragraph_id = p.id + WHERE ppm.problem_id = problem.id) AS "paragraph_count" +FROM problem problem diff --git a/apps/knowledge/task/embedding.py b/apps/knowledge/task/embedding.py index 8410118d035..750c28e15b0 100644 --- a/apps/knowledge/task/embedding.py +++ b/apps/knowledge/task/embedding.py @@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _ from common.config.embedding_config import ModelManage -from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingKnowledgeIdArgs, \ +from common.event.listener_manage import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingKnowledgeIdArgs, \ UpdateEmbeddingDocumentIdArgs from common.utils.logger import maxkb_logger from knowledge.models import Document, TaskType, State diff --git a/apps/knowledge/task/generate.py b/apps/knowledge/task/generate.py index 9a86a7ea2b3..bf89122a524 100644 --- a/apps/knowledge/task/generate.py +++ b/apps/knowledge/task/generate.py @@ -1,4 +1,3 @@ -import logging import traceback from celery_once import QueueOnce @@ -8,7 +7,7 @@ from langchain_core.messages import HumanMessage from common.config.embedding_config import ModelManage -from common.event import ListenerManagement +from common.event.listener_manage import ListenerManagement from common.utils.logger import maxkb_logger from common.utils.page_utils import page, page_desc from knowledge.models import Paragraph, Document, Status, TaskType, State diff --git a/apps/models_provider/impl/local_model_provider/model/__init__.py b/apps/local_model/__init__.py similarity index 100% rename from apps/models_provider/impl/local_model_provider/model/__init__.py rename to apps/local_model/__init__.py diff --git a/apps/local_model/admin.py 
b/apps/local_model/admin.py new file mode 100644 index 00000000000..8c38f3f3dad --- /dev/null +++ b/apps/local_model/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/apps/local_model/apps.py b/apps/local_model/apps.py new file mode 100644 index 00000000000..285ca727840 --- /dev/null +++ b/apps/local_model/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class LocalModelConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'local_model' diff --git a/apps/local_model/migrations/__init__.py b/apps/local_model/migrations/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/apps/local_model/models/__init__.py b/apps/local_model/models/__init__.py new file mode 100644 index 00000000000..8d63a092173 --- /dev/null +++ b/apps/local_model/models/__init__.py @@ -0,0 +1,10 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/25 15:04 + @desc: +""" + +from .model_management import * diff --git a/apps/local_model/models/model_management.py b/apps/local_model/models/model_management.py new file mode 100644 index 00000000000..ff3c0bf4e23 --- /dev/null +++ b/apps/local_model/models/model_management.py @@ -0,0 +1,49 @@ +# coding=utf-8 +import uuid_utils.compat as uuid + +from django.db import models + +from common.mixins.app_model_mixin import AppModelMixin +from local_model.models.user import User + + +class Status(models.TextChoices): + """系统设置类型""" + SUCCESS = "SUCCESS", '成功' + + ERROR = "ERROR", "失败" + + DOWNLOAD = "DOWNLOAD", '下载中' + + PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载' + + +class Model(AppModelMixin): + """ + 模型数据 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") + + name = models.CharField(max_length=128, verbose_name="名称", db_index=True) + + status = models.CharField(max_length=20, verbose_name='设置类型', choices=Status.choices, + default=Status.SUCCESS, db_index=True) + + model_type = models.CharField(max_length=128, verbose_name="模型类型", db_index=True) + + model_name = models.CharField(max_length=128, verbose_name="模型名称", db_index=True) + + user = models.ForeignKey(User, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True) + + provider = models.CharField(max_length=128, verbose_name='供应商', db_index=True) + + credential = models.CharField(max_length=102400, verbose_name="模型认证信息") + + meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) + + model_params_form = models.JSONField(verbose_name="模型参数配置", default=list) + workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True) + + class Meta: + db_table = "model" + unique_together = ['name', 'workspace_id'] diff --git a/apps/local_model/models/system_setting.py b/apps/local_model/models/system_setting.py new file mode 100644 index 00000000000..4b62d47a74d --- /dev/null +++ b/apps/local_model/models/system_setting.py @@ -0,0 +1,34 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: system_management.py + @date:2024/3/19 13:47 + @desc: 邮箱管理 +""" + +from django.db import models + +from common.mixins.app_model_mixin import AppModelMixin + + +class SettingType(models.IntegerChoices): + """系统设置类型""" + EMAIL = 0, '邮箱' + + RSA = 1, "私钥秘钥" + + LOG = 2, "日志清理时间" + + +class SystemSetting(AppModelMixin): + """ + 系统设置 + """ + type = models.IntegerField(primary_key=True, verbose_name='设置类型', choices=SettingType.choices, + 
default=SettingType.EMAIL) + + meta = models.JSONField(verbose_name="配置数据", default=dict) + + class Meta: + db_table = "system_setting" diff --git a/apps/local_model/models/user.py b/apps/local_model/models/user.py new file mode 100644 index 00000000000..0f480d89fde --- /dev/null +++ b/apps/local_model/models/user.py @@ -0,0 +1,38 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: user.py + @date:2025/4/14 10:20 + @desc: +""" +import uuid_utils.compat as uuid + +from django.db import models + +from common.utils.common import password_encrypt + + +class User(models.Model): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") + email = models.EmailField(unique=True, null=True, blank=True, verbose_name="邮箱", db_index=True) + phone = models.CharField(max_length=20, verbose_name="电话", default="", db_index=True) + nick_name = models.CharField(max_length=150, verbose_name="昵称", unique=True, db_index=True) + username = models.CharField(max_length=150, unique=True, verbose_name="用户名", db_index=True) + password = models.CharField(max_length=150, verbose_name="密码") + role = models.CharField(max_length=150, verbose_name="角色") + source = models.CharField(max_length=10, verbose_name="来源", default="LOCAL", db_index=True) + is_active = models.BooleanField(default=True, db_index=True) + language = models.CharField(max_length=10, verbose_name="语言", null=True, default=None) + create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True, null=True, db_index=True) + update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True, null=True, db_index=True) + + USERNAME_FIELD = 'username' + REQUIRED_FIELDS = [] + + class Meta: + db_table = "user" + + def set_password(self, row_password): + self.password = password_encrypt(row_password) + self._password = row_password diff --git a/apps/local_model/serializers/__init__.py b/apps/local_model/serializers/__init__.py new file mode 100644 index 00000000000..9bad5790a57 --- /dev/null +++ b/apps/local_model/serializers/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/apps/local_model/serializers/model_apply_serializers.py b/apps/local_model/serializers/model_apply_serializers.py new file mode 100644 index 00000000000..5ecb2260c9d --- /dev/null +++ b/apps/local_model/serializers/model_apply_serializers.py @@ -0,0 +1,160 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: model_apply_serializers.py + @date:2024/8/20 20:39 + @desc: +""" +import json +import threading +import time + +from django.db import connection +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ +from langchain_core.documents import Document +from rest_framework import serializers + +from local_model.models import Model +from local_model.serializers.rsa_util import rsa_long_decrypt +from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider + +from common.cache.mem_cache import MemCache + +_lock = threading.Lock() +locks = {} + + +class ModelManage: + cache = MemCache('model', {}) + up_clear_time = time.time() + + @staticmethod + def _get_lock(_id): + lock = locks.get(_id) + if lock is None: + with _lock: + lock = locks.get(_id) + if lock is None: + lock = threading.Lock() + locks[_id] = lock + + return lock + + @staticmethod + def get_model(_id, get_model): + model_instance = ModelManage.cache.get(_id) + if model_instance is None: + lock = ModelManage._get_lock(_id) + with lock: + model_instance = 
ModelManage.cache.get(_id) + if model_instance is None: + model_instance = get_model(_id) + ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8) + else: + if model_instance.is_cache_model(): + ModelManage.cache.touch(_id, timeout=60 * 60 * 8) + else: + model_instance = get_model(_id) + ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8) + ModelManage.clear_timeout_cache() + return model_instance + + @staticmethod + def clear_timeout_cache(): + if time.time() - ModelManage.up_clear_time > 60 * 60: + threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data()).start() + ModelManage.up_clear_time = time.time() + + @staticmethod + def delete_key(_id): + if ModelManage.cache.has_key(_id): + ModelManage.cache.delete(_id) + + +def get_local_model(model, **kwargs): + return LocalModelProvider().get_model(model.model_type, model.model_name, + json.loads( + rsa_long_decrypt(model.credential)), + model_id=model.id, + streaming=True, **kwargs) + + +def get_embedding_model(model_id): + model = QuerySet(Model).filter(id=model_id).first() + # 手动关闭数据库连接 + connection.close() + embedding_model = ModelManage.get_model(model_id, + lambda _id: get_local_model(model, use_local=True)) + return embedding_model + + +class EmbedDocuments(serializers.Serializer): + texts = serializers.ListField(required=True, child=serializers.CharField(required=True, + label=_('vector text')), + label=_('vector text list')), + + +class EmbedQuery(serializers.Serializer): + text = serializers.CharField(required=True, label=_('vector text')) + + +class CompressDocument(serializers.Serializer): + page_content = serializers.CharField(required=True, label=_('text')) + metadata = serializers.DictField(required=False, label=_('metadata')) + + +class CompressDocuments(serializers.Serializer): + documents = CompressDocument(required=True, many=True) + query = serializers.CharField(required=True, label=_('query')) + + +class ValidateModelSerializers(serializers.Serializer): + model_name = serializers.CharField(required=True, label=_('model_name')) + + model_type = serializers.CharField(required=True, label=_('model_type')) + + model_credential = serializers.DictField(required=True, label="credential") + + def validate_model(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'), + self.data.get('model_credential'), model_params={}, + raise_exception=True) + + +class ModelApplySerializers(serializers.Serializer): + model_id = serializers.UUIDField(required=True, label=_('model id')) + + def embed_documents(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + EmbedDocuments(data=instance).is_valid(raise_exception=True) + + model = get_embedding_model(self.data.get('model_id')) + return model.embed_documents(instance.getlist('texts')) + + def embed_query(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + EmbedQuery(data=instance).is_valid(raise_exception=True) + + model = get_embedding_model(self.data.get('model_id')) + return model.embed_query(instance.get('text')) + + def compress_documents(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + CompressDocuments(data=instance).is_valid(raise_exception=True) + model = get_embedding_model(self.data.get('model_id')) + return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents( + 
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in + instance.get('documents')], instance.get('query'))] + + def unload(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + ModelManage.delete_key(self.data.get('model_id')) + return True diff --git a/apps/local_model/serializers/rsa_util.py b/apps/local_model/serializers/rsa_util.py new file mode 100644 index 00000000000..df2cedba736 --- /dev/null +++ b/apps/local_model/serializers/rsa_util.py @@ -0,0 +1,139 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: rsa_util.py + @date:2023/11/3 11:13 + @desc: +""" +import base64 +import threading + +from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher +from Crypto.PublicKey import RSA +from django.core import cache +from django.db.models import QuerySet + +from common.constants.cache_version import Cache_Version +from local_model.models.system_setting import SystemSetting, SettingType + +lock = threading.Lock() +rsa_cache = cache.cache +cache_key = "rsa_key" +# 对密钥加密的密码 +secret_code = "mac_kb_password" + + +def generate(): + """ + 生成 私钥秘钥对 + :return:{key:'公钥',value:'私钥'} + """ + # 生成一个 2048 位的密钥 + key = RSA.generate(2048) + + # 获取私钥 + encrypted_key = key.export_key(passphrase=secret_code, pkcs=8, + protection="scryptAndAES128-CBC") + return {'key': key.publickey().export_key(), 'value': encrypted_key} + + +def get_key_pair(): + rsa_value = rsa_cache.get(cache_key) + if rsa_value is None: + with lock: + rsa_value = rsa_cache.get(cache_key) + if rsa_value is not None: + return rsa_value + rsa_value = get_key_pair_by_sql() + version, get_key = Cache_Version.SYSTEM.value + rsa_cache.set(get_key(key='rsa_key'), rsa_value, timeout=None, version=version) + return rsa_value + + +def get_key_pair_by_sql(): + system_setting = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first() + if system_setting is None: + kv = generate() + system_setting = SystemSetting(type=SettingType.RSA.value, + meta={'key': kv.get('key').decode(), 'value': kv.get('value').decode()}) + system_setting.save() + return system_setting.meta + + +def encrypt(msg, public_key: str | None = None): + """ + 加密 + :param msg: 加密数据 + :param public_key: 公钥 + :return: 加密后的数据 + """ + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(public_key)) + encrypt_msg = cipher.encrypt(msg.encode("utf-8")) + return base64.b64encode(encrypt_msg).decode() + + +def decrypt(msg, pri_key: str | None = None): + """ + 解密 + :param msg: 需要解密的数据 + :param pri_key: 私钥 + :return: 解密后数据 + """ + if pri_key is None: + pri_key = get_key_pair().get('value') + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) + return decrypt_data.decode("utf-8") + + +def rsa_long_encrypt(message, public_key: str | None = None, length=200): + """ + 超长文本加密 + + :param message: 需要加密的字符串 + :param public_key 公钥 + :param length: 1024bit的证书用100, 2048bit的证书用 200 + :return: 加密后的数据 + """ + # 读取公钥 + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, + passphrase=secret_code)) + # 处理:Plaintext is too long. 
分段加密
+    if len(message) <= length:
+        # 对编码的数据进行加密,并通过base64进行编码
+        result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
+    else:
+        rsa_text = []
+        # 对编码后的数据进行切片,原因:加密长度不能过长
+        for i in range(0, len(message), length):
+            cont = message[i:i + length]
+            # 对切片后的数据进行加密,并新增到text后面
+            rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
+        # 加密完进行拼接
+        cipher_text = b''.join(rsa_text)
+        # base64进行编码
+        result = base64.b64encode(cipher_text)
+    return result.decode()
+
+
+def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
+    """
+    超长文本解密,默认不加密
+    :param message: 需要解密的数据
+    :param pri_key: 秘钥
+    :param length : 1024bit的证书用128,2048bit证书用256位
+    :return: 解密后的数据
+    """
+    if pri_key is None:
+        pri_key = get_key_pair().get('value')
+    cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
+    base64_de = base64.b64decode(message)
+    res = []
+    for i in range(0, len(base64_de), length):
+        res.append(cipher.decrypt(base64_de[i:i + length], 0))
+    return b"".join(res).decode()
diff --git a/apps/local_model/tests.py b/apps/local_model/tests.py
new file mode 100644
index 00000000000..7ce503c2dd9
--- /dev/null
+++ b/apps/local_model/tests.py
@@ -0,0 +1,3 @@
+from django.test import TestCase
+
+# Create your tests here.
diff --git a/apps/local_model/urls.py b/apps/local_model/urls.py
new file mode 100644
index 00000000000..a9c060254a3
--- /dev/null
+++ b/apps/local_model/urls.py
@@ -0,0 +1,15 @@
+import os
+
+from django.urls import path
+
+from . import views
+
+app_name = "local_model"
+# @formatter:off
+urlpatterns = [
+    path('model/validate', views.LocalModelApply.Validate.as_view()),
+    path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
+    path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
+    path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
+    path('model/<str:model_id>/unload', views.LocalModelApply.Unload.as_view()),
+    ]
diff --git a/apps/local_model/views/__init__.py b/apps/local_model/views/__init__.py
new file mode 100644
index 00000000000..b9dd8b0a4a4
--- /dev/null
+++ b/apps/local_model/views/__init__.py
@@ -0,0 +1,2 @@
+# coding=utf-8
+from .model_apply import *
diff --git a/apps/local_model/views/model_apply.py b/apps/local_model/views/model_apply.py
new file mode 100644
index 00000000000..98c07dd7493
--- /dev/null
+++ b/apps/local_model/views/model_apply.py
@@ -0,0 +1,43 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: model_apply.py
+    @date:2024/8/20 20:38
+    @desc:
+"""
+from rest_framework.request import Request
+
+from rest_framework.views import APIView
+
+from common.result import result
+from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers
+
+
+class LocalModelApply(APIView):
+    class EmbedDocuments(APIView):
+
+        def post(self, request: Request, model_id):
+            return result.success(
+                ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data))
+
+    class EmbedQuery(APIView):
+
+        def post(self, request: Request, model_id):
+            return result.success(
+                ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data))
+
+    class CompressDocuments(APIView):
+
+        def post(self, request: Request, model_id):
+            return result.success(
+                ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
+
+    class Unload(APIView):
+        def post(self, request: Request, model_id):
+            return result.success(
+                ModelApplySerializers(data={'model_id': model_id}).unload())
+
+    class 
Validate(APIView): + def post(self, request: Request): + return result.success(ValidateModelSerializers(data=request.data).validate_model()) diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index af7b713d158..a2449479017 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -2207,7 +2207,7 @@ msgid "parent id" msgstr "" #: apps/folders/serializers/folder.py:75 -msgid "Folder depth cannot exceed 3 levels" +msgid "Folder depth cannot exceed 5 levels" msgstr "" #: apps/folders/serializers/folder.py:100 @@ -8763,4 +8763,7 @@ msgid "Tag value already exists" msgstr "" msgid "Non-existent id" +msgstr "" + +msgid "No permission for the target folder" msgstr "" \ No newline at end of file diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index f5160241bc0..793bef404d1 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -2214,8 +2214,8 @@ msgid "parent id" msgstr "父级 ID" #: apps/folders/serializers/folder.py:75 -msgid "Folder depth cannot exceed 3 levels" -msgstr "文件夹深度不能超过3级" +msgid "Folder depth cannot exceed 5 levels" +msgstr "文件夹深度不能超过5级" #: apps/folders/serializers/folder.py:100 msgid "folder user id" @@ -8890,3 +8890,7 @@ msgstr "标签值已存在" msgid "Non-existent id" msgstr "不存在的ID" + +msgid "No permission for the target folder" +msgstr "没有目标文件夹的权限" + diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index 01749b1ef09..e543737505c 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -2214,8 +2214,8 @@ msgid "parent id" msgstr "父級 ID" #: apps/folders/serializers/folder.py:75 -msgid "Folder depth cannot exceed 3 levels" -msgstr "文件夾深度不能超過3級" +msgid "Folder depth cannot exceed 5 levels" +msgstr "文件夾深度不能超過5級" #: apps/folders/serializers/folder.py:100 msgid "folder user id" @@ -8890,3 +8890,7 @@ msgstr "標籤值已存在" msgid "Non-existent id" msgstr "不存在的ID" + +msgid "No permission for the target folder" +msgstr "沒有目標資料夾的權限" + diff --git a/apps/maxkb/settings/__init__.py b/apps/maxkb/settings/__init__.py index 8333fa1bd4f..e973afc3014 100644 --- a/apps/maxkb/settings/__init__.py +++ b/apps/maxkb/settings/__init__.py @@ -9,4 +9,5 @@ from .base import * from .logging import * from .auth import * -from .lib import * \ No newline at end of file +from .lib import * +from .mem import * \ No newline at end of file diff --git a/apps/maxkb/settings/auth/__init__.py b/apps/maxkb/settings/auth/__init__.py new file mode 100644 index 00000000000..b7996e1f56f --- /dev/null +++ b/apps/maxkb/settings/auth/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py + @date:2025/11/5 14:50 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/maxkb/settings/auth/model.py b/apps/maxkb/settings/auth/model.py new file mode 100644 index 00000000000..a210130254a --- /dev/null +++ b/apps/maxkb/settings/auth/model.py @@ -0,0 +1,11 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: auth.py + @date:2024/7/9 18:47 + @desc: +""" + +AUTH_HANDLES = [ +] diff --git a/apps/maxkb/settings/auth.py b/apps/maxkb/settings/auth/web.py similarity index 100% rename from apps/maxkb/settings/auth.py rename to apps/maxkb/settings/auth/web.py diff --git a/apps/maxkb/settings/base/__init__.py 
b/apps/maxkb/settings/base/__init__.py new file mode 100644 index 00000000000..65d1845bbbf --- /dev/null +++ b/apps/maxkb/settings/base/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/5 14:53 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/maxkb/settings/base/model.py b/apps/maxkb/settings/base/model.py new file mode 100644 index 00000000000..a4736423749 --- /dev/null +++ b/apps/maxkb/settings/base/model.py @@ -0,0 +1,179 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: model.py + @date:2025/11/5 14:53 + @desc: +""" + +from pathlib import Path +from ...const import CONFIG, PROJECT_DIR +import os +from django.utils.translation import gettext_lazy as _ + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent.parent + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = CONFIG.get("SECRET_KEY") or 'django-insecure-zm^1_^i5)3gp^&0io6zg72&z!a*d=9kf9o2%uft+27l)+t(#3e' + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = CONFIG.get_debug() + +ALLOWED_HOSTS = ['*'] + +# Application definition + +INSTALLED_APPS = [ + 'django.contrib.contenttypes', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'rest_framework', + 'local_model', +] + +MIDDLEWARE = [ + 'django.middleware.locale.LocaleMiddleware', + 'django.middleware.security.SecurityMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + +] + +REST_FRAMEWORK = { + 'EXCEPTION_HANDLER': 'common.exception.handle_exception.handle_exception', + 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', + 'DEFAULT_AUTHENTICATION_CLASSES': ['common.auth.authenticate.AnonymousAuthentication'] +} +STATICFILES_DIRS = [(os.path.join(PROJECT_DIR, 'ui', 'dist'))] +STATIC_ROOT = os.path.join(BASE_DIR.parent, 'static') +ROOT_URLCONF = 'maxkb.urls' +APPS_DIR = os.path.join(PROJECT_DIR, 'apps') + +TEMPLATES = [ + { + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': ["apps/static/admin"], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, + {"NAME": "CHAT", + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': ["apps/static/chat"], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, + {"NAME": "DOC", + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': ["apps/static/drf_spectacular_sidecar"], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, +] +SPECTACULAR_SETTINGS = { + 'TITLE': 'MaxKB API', + 'DESCRIPTION': _('Intelligent customer service platform'), + 'VERSION': 
'v2', + 'SERVE_INCLUDE_SCHEMA': False, + # OTHER SETTINGS + 'SWAGGER_UI_DIST': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist', # shorthand to use the sidecar instead + 'SWAGGER_UI_FAVICON_HREF': f'{CONFIG.get_admin_path()}/api-doc/swagger-ui-dist/favicon-32x32.png', + 'REDOC_DIST': f'{CONFIG.get_admin_path()}/api-doc/redoc', + 'SECURITY_DEFINITIONS': { + 'Bearer': { + 'type': 'apiKey', + 'name': 'AUTHORIZATION', + 'in': 'header', + } + } +} +WSGI_APPLICATION = 'maxkb.wsgi.application' + +# Database +# https://docs.djangoproject.com/en/4.2/ref/settings/#databases + +DATABASES = {'default': CONFIG.get_db_setting()} + +CACHES = CONFIG.get_cache_setting() + +# Password validation +# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + }, +] + +# Internationalization +# https://docs.djangoproject.com/en/4.2/topics/i18n/ + +LANGUAGE_CODE = CONFIG.get("LANGUAGE_CODE") + +TIME_ZONE = CONFIG.get_time_zone() + +USE_I18N = True + +USE_TZ = True + +# 文件上传配置 +DATA_UPLOAD_MAX_NUMBER_FILES = 1000 + +# 支持的语言 +LANGUAGES = [ + ('en', 'English'), + ('zh', '中文简体'), + ('zh-hant', '中文繁体') +] +# 翻译文件路径 +LOCALE_PATHS = [ + os.path.join(BASE_DIR.parent, 'locales') +] + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/4.2/howto/static-files/ + +STATIC_URL = 'static/' + +# Default primary key field type +# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field + +DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' + +edition = 'CE' + +if os.environ.get('MAXKB_REDIS_SENTINEL_SENTINELS') is not None: + DJANGO_REDIS_CONNECTION_FACTORY = "django_redis.pool.SentinelConnectionFactory" diff --git a/apps/maxkb/settings/base.py b/apps/maxkb/settings/base/web.py similarity index 93% rename from apps/maxkb/settings/base.py rename to apps/maxkb/settings/base/web.py index 780f1f32e37..04f6ae5a350 100644 --- a/apps/maxkb/settings/base.py +++ b/apps/maxkb/settings/base/web.py @@ -1,22 +1,18 @@ +# coding=utf-8 """ -Django settings for maxkb project. - -Generated by 'django-admin startproject' using Django 4.2.4. - -For more information on this file, see -https://docs.djangoproject.com/en/4.2/topics/settings/ - -For the full list of settings and their values, see -https://docs.djangoproject.com/en/4.2/ref/settings/ + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/5 14:53 + @desc: """ - from pathlib import Path -from ..const import CONFIG, PROJECT_DIR +from ...const import CONFIG, PROJECT_DIR import os from django.utils.translation import gettext_lazy as _ # Build paths inside the project like this: BASE_DIR / 'subdir'. 
-BASE_DIR = Path(__file__).resolve().parent.parent +BASE_DIR = Path(__file__).resolve().parent.parent.parent # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/ diff --git a/apps/maxkb/settings/mem.py b/apps/maxkb/settings/mem.py new file mode 100644 index 00000000000..d4e2ccc8e9c --- /dev/null +++ b/apps/maxkb/settings/mem.py @@ -0,0 +1,22 @@ +# coding=utf-8 +import os +import gc +import threading +from maxkb.const import CONFIG +from common.utils.logger import maxkb_logger +import random + +CURRENT_PID=os.getpid() +# 1 hour +GC_INTERVAL = 3600 + +def enable_force_gc(): + collected = gc.collect() + maxkb_logger.debug(f"(PID: {CURRENT_PID}) Forced GC ({collected} objects collected)") + t = threading.Timer(GC_INTERVAL - random.randint(0, 900), enable_force_gc) + t.daemon = True + t.start() + +if CONFIG.get("ENABLE_FORCE_GC", '1') == "1": + maxkb_logger.info(f"(PID: {CURRENT_PID}) Forced GC enabled") + enable_force_gc() diff --git a/apps/maxkb/urls/__init__.py b/apps/maxkb/urls/__init__.py new file mode 100644 index 00000000000..f788f029ea6 --- /dev/null +++ b/apps/maxkb/urls/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB-xpack + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/5 14:45 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/maxkb/urls/model.py b/apps/maxkb/urls/model.py new file mode 100644 index 00000000000..6e3a11dd063 --- /dev/null +++ b/apps/maxkb/urls/model.py @@ -0,0 +1,28 @@ +""" +URL configuration for maxkb project. + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/4.2/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: path('', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.urls import include, path + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) +""" + +from django.urls import path, include + +from maxkb.const import CONFIG + +admin_api_prefix = CONFIG.get_admin_path()[1:] + '/api/' +admin_ui_prefix = CONFIG.get_admin_path() +chat_api_prefix = CONFIG.get_chat_path()[1:] + '/api/' +chat_ui_prefix = CONFIG.get_chat_path() +urlpatterns = [ + path(admin_api_prefix, include("local_model.urls")), +] diff --git a/apps/maxkb/urls.py b/apps/maxkb/urls/web.py similarity index 100% rename from apps/maxkb/urls.py rename to apps/maxkb/urls/web.py diff --git a/apps/maxkb/wsgi.py b/apps/maxkb/wsgi.py deleted file mode 100644 index fc271a7c333..00000000000 --- a/apps/maxkb/wsgi.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -WSGI config for maxkb project. - -It exposes the WSGI callable as a module-level variable named ``application``. 
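The new apps/maxkb/settings/mem.py shown above forces a gc.collect() roughly once an hour, subtracting a random 0-900 seconds from the interval so that worker processes started together do not all pause for collection at the same moment; the timer is a daemon thread, so it dies with the process. A minimal standalone restatement of that rescheduling loop (the only change is print in place of maxkb_logger, so the sketch runs outside MaxKB):

# Simplified restatement of the forced-GC timer from apps/maxkb/settings/mem.py;
# logging is swapped for print so this runs without the MaxKB logger or CONFIG.
import gc
import os
import random
import threading

CURRENT_PID = os.getpid()
GC_INTERVAL = 3600  # base period: one hour


def enable_force_gc():
    collected = gc.collect()
    print(f"(PID: {CURRENT_PID}) Forced GC ({collected} objects collected)")
    # Jitter the next run by up to 15 minutes so parallel workers
    # do not all collect at the same wall-clock time.
    timer = threading.Timer(GC_INTERVAL - random.randint(0, 900), enable_force_gc)
    timer.daemon = True  # the timer thread dies with the process
    timer.start()


if __name__ == "__main__":
    enable_force_gc()

In MaxKB itself the module is pulled in through the settings package and gated by the ENABLE_FORCE_GC config key, which defaults to on.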
- -For more information on this file, see -https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/ -""" - -import os - -from django.core.wsgi import get_wsgi_application - -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings') - -application = get_wsgi_application() - - - -def post_handler(): - from common.database_model_manage.database_model_manage import DatabaseModelManage - from common import event - - event.run() - DatabaseModelManage.init() - - -def post_scheduler_handler(): - from common import job - - job.run() - -# 启动后处理函数 -post_handler() - -# 仅在scheduler中启动定时任务,dev local_model celery 不需要 -if os.environ.get('ENABLE_SCHEDULER') == '1': - post_scheduler_handler() diff --git a/apps/maxkb/wsgi/__init__.py b/apps/maxkb/wsgi/__init__.py new file mode 100644 index 00000000000..58649348fc6 --- /dev/null +++ b/apps/maxkb/wsgi/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/5 15:14 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/maxkb/wsgi/model.py b/apps/maxkb/wsgi/model.py new file mode 100644 index 00000000000..1d70874be1d --- /dev/null +++ b/apps/maxkb/wsgi/model.py @@ -0,0 +1,15 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: model.py + @date:2025/11/5 15:14 + @desc: +""" +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings') + +application = get_wsgi_application() \ No newline at end of file diff --git a/apps/maxkb/wsgi/web.py b/apps/maxkb/wsgi/web.py new file mode 100644 index 00000000000..d1fa687d8ee --- /dev/null +++ b/apps/maxkb/wsgi/web.py @@ -0,0 +1,48 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/5 15:14 + @desc: +""" +import builtins +import os +import sys + +from django.core.wsgi import get_wsgi_application + + +class TorchBlocker: + def __init__(self): + self.original_import = builtins.__import__ + + def __call__(self, name, *args, **kwargs): + if len([True for i in + ['torch'] + if + i in name.lower()]) > 0: + print(f"Disable package is being imported: 【{name}】", file=sys.stderr) + pass + else: + return self.original_import(name, *args, **kwargs) + + +# 安装导入拦截器 +builtins.__import__ = TorchBlocker() + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings') +os.environ['TIKTOKEN_CACHE_DIR'] = '/opt/maxkb-app/model/tokenizer/openai-tiktoken-cl100k-base' +application = get_wsgi_application() + + +def post_handler(): + from common.database_model_manage.database_model_manage import DatabaseModelManage + from common import event + + event.run() + DatabaseModelManage.init() + + +# 启动后处理函数 +post_handler() diff --git a/apps/models_provider/impl/gemini_model_provider/model/image.py b/apps/models_provider/impl/gemini_model_provider/model/image.py index 1f4e97a182e..e4d605c1c03 100644 --- a/apps/models_provider/impl/gemini_model_provider/model/image.py +++ b/apps/models_provider/impl/gemini_model_provider/model/image.py @@ -23,6 +23,5 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** return GeminiImage( model=model_name, google_api_key=model_credential.get('api_key'), - streaming=True, **optional_params, ) diff --git a/apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py b/apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py new file mode 100644 index 
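The TorchBlocker installed in the new apps/maxkb/wsgi/web.py swaps out builtins.__import__ so that any import whose name contains "torch" is swallowed in the web worker; the heavy model stack is meant to load only in the local_model service. Below is a self-contained reduction of that behaviour: the class body is condensed from the diff, and the two assertions are assumptions about how a blocked import surfaces, based on the interceptor returning None for a banned name.

# Condensed version of the import interceptor from apps/maxkb/wsgi/web.py,
# plus a small demonstration of what a blocked import looks like to the caller.
import builtins
import sys


class TorchBlocker:
    def __init__(self):
        self.original_import = builtins.__import__

    def __call__(self, name, *args, **kwargs):
        if 'torch' in name.lower():
            print(f"Disable package is being imported: 【{name}】", file=sys.stderr)
            return None  # the banned module is never actually loaded
        return self.original_import(name, *args, **kwargs)


builtins.__import__ = TorchBlocker()

blocked = __import__('torch')  # prints the warning and yields None
assert blocked is None
json_mod = __import__('json')  # unrelated packages import normally
assert json_mod is not None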
00000000000..29828bb7401 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/7 14:02 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/models_provider/impl/local_model_provider/credential/embedding.py b/apps/models_provider/impl/local_model_provider/credential/embedding/model.py similarity index 96% rename from apps/models_provider/impl/local_model_provider/credential/embedding.py rename to apps/models_provider/impl/local_model_provider/credential/embedding/model.py index 9d656ad9833..d9ec4c3daff 100644 --- a/apps/models_provider/impl/local_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/local_model_provider/credential/embedding/model.py @@ -1,9 +1,9 @@ # coding=utf-8 """ @project: MaxKB - @Author:虎 - @file: embedding.py - @date:2024/7/11 11:06 + @Author:虎虎 + @file: model.py.py + @date:2025/11/7 14:02 @desc: """ import traceback diff --git a/apps/models_provider/impl/local_model_provider/credential/embedding/web.py b/apps/models_provider/impl/local_model_provider/credential/embedding/web.py new file mode 100644 index 00000000000..4695d141c6d --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/embedding/web.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/7 14:03 + @desc: +""" +from typing import Dict + +import requests +from django.utils.translation import gettext_lazy as _ + +from common import forms +from common.forms import BaseForm +from maxkb.const import CONFIG +from models_provider.base_model_provider import BaseModelCredential + + +class LocalEmbeddingCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate', + json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_folder = forms.TextInputField(_('Model catalog'), required=True) diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py b/apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py new file mode 100644 index 00000000000..f9ec12bc56c --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/7 14:22 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker.py b/apps/models_provider/impl/local_model_provider/credential/reranker/model.py similarity index 91% rename from apps/models_provider/impl/local_model_provider/credential/reranker.py rename to apps/models_provider/impl/local_model_provider/credential/reranker/model.py index 
94341d52f42..85b2abce90c 100644 --- a/apps/models_provider/impl/local_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/model.py @@ -1,9 +1,9 @@ # coding=utf-8 """ @project: MaxKB - @Author:虎 - @file: reranker.py - @date:2024/9/3 14:33 + @Author:虎虎 + @file: model.py + @date:2025/11/7 14:23 @desc: """ import traceback @@ -15,7 +15,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode -from models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker +from models_provider.impl.local_model_provider.model.reranker import LocalReranker from django.utils.translation import gettext_lazy as _, gettext @@ -33,7 +33,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje else: return False try: - model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential) + model: LocalReranker = provider.get_model(model_type, model_name, model_credential) model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello')) except Exception as e: traceback.print_exc() diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker/web.py b/apps/models_provider/impl/local_model_provider/credential/reranker/web.py new file mode 100644 index 00000000000..bc86982bf08 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/web.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/7 14:23 + @desc: +""" +from typing import Dict + +import requests +from django.utils.translation import gettext_lazy as _ + +from common import forms +from common.forms import BaseForm +from maxkb.const import CONFIG +from models_provider.base_model_provider import BaseModelCredential + + +class LocalRerankerCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate', + json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_folder = forms.TextInputField(_('Model catalog'), required=True) diff --git a/apps/models_provider/impl/local_model_provider/local_model_provider.py b/apps/models_provider/impl/local_model_provider/local_model_provider.py index 14603962a68..342f585f4d4 100644 --- a/apps/models_provider/impl/local_model_provider/local_model_provider.py +++ b/apps/models_provider/impl/local_model_provider/local_model_provider.py @@ -8,15 +8,16 @@ """ import os +from django.utils.translation import gettext as _ + from common.utils.common import get_file_content +from maxkb.conf import PROJECT_DIR from models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage from models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential from models_provider.impl.local_model_provider.credential.reranker import 
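Both of the new credential web.py classes above (embedding and reranker) validate a model by POSTing to the local_model service endpoint {LOCAL_MODEL_PROTOCOL}://{LOCAL_MODEL_HOST}:{LOCAL_MODEL_PORT}{admin_path}/api/model/validate and then unwrapping the same JSON envelope: code == 200 means success and data carries the payload; anything else raises with message. A small helper that captures that convention follows; the endpoint path and envelope keys come from the diff, while the host, port and credential values in the example are placeholders.

# Sketch of the response-envelope convention used by the local_model HTTP calls
# ({"code": ..., "data": ..., "message": ...}); host, port and payload values
# below are placeholders, not MaxKB defaults.
import requests


def call_local_model(url: str, payload: dict):
    res = requests.post(url, json=payload)
    result = res.json()
    if result.get('code', 500) == 200:
        return result.get('data')
    raise Exception(result.get('message'))


if __name__ == '__main__':
    data = call_local_model(
        'http://127.0.0.1:8000/admin/api/model/validate',  # placeholder bind and admin path
        {'model_name': 'shibing624/text2vec-base-chinese',
         'model_type': 'EMBEDDING',                         # placeholder type value
         'model_credential': {'cache_folder': '/opt/maxkb-app/model'}},
    )
    print(data)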
LocalRerankerCredential from models_provider.impl.local_model_provider.model.embedding import LocalEmbedding from models_provider.impl.local_model_provider.model.reranker import LocalReranker -from maxkb.conf import PROJECT_DIR -from django.utils.translation import gettext as _ embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING, LocalEmbeddingCredential(), LocalEmbedding) diff --git a/apps/models_provider/impl/local_model_provider/model/embedding.py b/apps/models_provider/impl/local_model_provider/model/embedding.py deleted file mode 100644 index 94ea652af12..00000000000 --- a/apps/models_provider/impl/local_model_provider/model/embedding.py +++ /dev/null @@ -1,64 +0,0 @@ -# coding=utf-8 -""" - @project: MaxKB - @Author:虎 - @file: embedding.py - @date:2024/7/11 14:06 - @desc: -""" -from typing import Dict, List - -import requests -from langchain_core.embeddings import Embeddings -from pydantic import BaseModel -from langchain_huggingface import HuggingFaceEmbeddings - -from models_provider.base_model_provider import MaxKBBaseModel -from maxkb.const import CONFIG - - -class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - pass - - model_id: str = None - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.model_id = kwargs.get('model_id', None) - - def embed_query(self, text: str) -> List[float]: - bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' - prefix = CONFIG.get_admin_path() - res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_query', - {'text': text}) - result = res.json() - if result.get('code', 500) == 200: - return result.get('data') - raise Exception(result.get('message')) - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' - prefix = CONFIG.get_admin_path() - res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/{prefix}/api/model/{self.model_id}/embed_documents', - {'texts': texts}) - result = res.json() - if result.get('code', 500) == 200: - return result.get('data') - raise Exception(result.get('message')) - - -class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings): - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - if model_kwargs.get('use_local', True): - return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), - model_kwargs={'device': model_credential.get('device')}, - encode_kwargs={'normalize_embeddings': True} - ) - return WebLocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), - model_kwargs={'device': model_credential.get('device')}, - encode_kwargs={'normalize_embeddings': True}, - **model_kwargs) diff --git a/apps/models_provider/impl/local_model_provider/model/embedding/__init__.py b/apps/models_provider/impl/local_model_provider/model/embedding/__init__.py new file mode 100644 index 00000000000..840afa5afc4 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/embedding/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py + @date:2025/11/5 15:24 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from 
.web import * diff --git a/apps/models_provider/impl/local_model_provider/model/embedding/model.py b/apps/models_provider/impl/local_model_provider/model/embedding/model.py new file mode 100644 index 00000000000..7ebc41cb150 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/embedding/model.py @@ -0,0 +1,26 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: model.py + @date:2025/11/5 15:26 + @desc: +""" +from typing import Dict + +from langchain_huggingface import HuggingFaceEmbeddings + +from models_provider.base_model_provider import MaxKBBaseModel + + +class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings): + @staticmethod + def is_cache_model(): + return True + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True} + ) diff --git a/apps/models_provider/impl/local_model_provider/model/embedding/web.py b/apps/models_provider/impl/local_model_provider/model/embedding/web.py new file mode 100644 index 00000000000..bfc22bc9ba6 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/embedding/web.py @@ -0,0 +1,54 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/5 15:24 + @desc: +""" + +from typing import Dict, List + +import requests +from anthropic import BaseModel +from langchain_core.embeddings import Embeddings + +from maxkb.const import CONFIG +from models_provider.base_model_provider import MaxKBBaseModel + + +class LocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model_id = kwargs.get('model_id', None) + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True}, + **model_kwargs) + + model_id: str = None + + def embed_query(self, text: str) -> List[float]: + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/embed_query', + {'text': text}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/{prefix}/api/model/{self.model_id}/embed_documents', + {'texts': texts}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) diff --git a/apps/models_provider/impl/local_model_provider/model/reranker.py b/apps/models_provider/impl/local_model_provider/model/reranker.py deleted file mode 100644 index 5ac2a4ca838..00000000000 --- a/apps/models_provider/impl/local_model_provider/model/reranker.py +++ /dev/null @@ -1,102 +0,0 @@ -# coding=utf-8 -""" - @project: MaxKB - @Author:虎 - @file: reranker.py.py - @date:2024/9/2 16:42 - @desc: -""" -from typing import 
Sequence, Optional, Dict, Any, ClassVar - -import requests -import torch -from langchain_core.callbacks import Callbacks -from langchain_core.documents import BaseDocumentCompressor, Document -from transformers import AutoModelForSequenceClassification, AutoTokenizer - -from models_provider.base_model_provider import MaxKBBaseModel -from maxkb.const import CONFIG - - -class LocalReranker(MaxKBBaseModel): - def __init__(self, model_name, top_n=3, cache_dir=None): - super().__init__() - self.model_name = model_name - self.cache_dir = cache_dir - self.top_n = top_n - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - if model_kwargs.get('use_local', True): - return LocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'), - model_kwargs={'device': model_credential.get('device', 'cpu')} - - ) - return WebLocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'), - model_kwargs={'device': model_credential.get('device')}, - **model_kwargs) - - -class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor): - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - pass - - model_id: str = None - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.model_id = kwargs.get('model_id', None) - - def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ - Sequence[Document]: - if documents is None or len(documents) == 0: - return [] - prefix = CONFIG.get_admin_path() - bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' - res = requests.post( - f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents', - json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in - documents], 'query': query}, headers={'Content-Type': 'application/json'}) - result = res.json() - if result.get('code', 500) == 200: - return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document - in result.get('data')] - raise Exception(result.get('message')) - - -class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor): - client: Any = None - tokenizer: Any = None - model: Optional[str] = None - cache_dir: Optional[str] = None - model_kwargs: Any = {} - - def __init__(self, model_name, cache_dir=None, **model_kwargs): - super().__init__() - self.model = model_name - self.cache_dir = cache_dir - self.model_kwargs = model_kwargs - self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir) - self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir) - self.client = self.client.to(self.model_kwargs.get('device', 'cpu')) - self.client.eval() - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs) - - def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ - Sequence[Document]: - if documents is None or len(documents) == 0: - return [] - with torch.no_grad(): - inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True, - truncation=True, return_tensors='pt', max_length=512) - scores = 
[torch.sigmoid(s).float().item() for s in - self.client(**inputs, return_dict=True).logits.view(-1, ).float()] - result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]}) - for index - in range(len(documents))] - result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True) - return result diff --git a/apps/models_provider/impl/local_model_provider/model/reranker/__init__.py b/apps/models_provider/impl/local_model_provider/model/reranker/__init__.py new file mode 100644 index 00000000000..10d5b1bb68d --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/reranker/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/5 15:30 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/models_provider/impl/local_model_provider/model/reranker/model.py b/apps/models_provider/impl/local_model_provider/model/reranker/model.py new file mode 100644 index 00000000000..a66776101b4 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/model/reranker/model.py @@ -0,0 +1,58 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: model.py + @date:2025/11/5 15:30 + @desc: +""" + +from typing import Sequence, Optional, Dict, Any + +from langchain_core.callbacks import Callbacks +from langchain_core.documents import Document, BaseDocumentCompressor + +from models_provider.base_model_provider import MaxKBBaseModel + + +class LocalReranker(MaxKBBaseModel, BaseDocumentCompressor): + client: Any = None + tokenizer: Any = None + model: Optional[str] = None + cache_dir: Optional[str] = None + model_kwargs: Any = {} + + def __init__(self, model_name, cache_dir=None, **model_kwargs): + super().__init__() + from transformers import AutoModelForSequenceClassification, AutoTokenizer + self.model = model_name + self.cache_dir = cache_dir + self.model_kwargs = model_kwargs + self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir) + self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir) + self.client = self.client.to(self.model_kwargs.get('device', 'cpu')) + self.client.eval() + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalReranker(model_name, cache_dir=model_credential.get('cache_dir')) + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + if documents is None or len(documents) == 0: + return [] + import torch + with torch.no_grad(): + inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True, + truncation=True, return_tensors='pt', max_length=512) + scores = [torch.sigmoid(s).float().item() for s in + self.client(**inputs, return_dict=True).logits.view(-1, ).float()] + result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]}) + for index + in range(len(documents))] + result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True) + return result diff --git a/apps/models_provider/impl/local_model_provider/model/reranker/web.py b/apps/models_provider/impl/local_model_provider/model/reranker/web.py new file mode 100644 index 00000000000..45ab6978a3a --- 
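The rewritten local reranker (model/reranker/model.py above) is a plain cross-encoder: every (query, document) pair is tokenized together, the classification logit is squashed with a sigmoid, and the documents come back sorted by that relevance_score. A hedged usage sketch, meant to run inside the local_model process (SERVER_NAME=local_model selects the HuggingFace-backed class); the checkpoint name and cache directory are illustrative assumptions, not values taken from the diff.

# Usage sketch for the cross-encoder LocalReranker in
# apps/models_provider/impl/local_model_provider/model/reranker/model.py.
# The checkpoint name and cache_dir are assumptions for illustration only.
from langchain_core.documents import Document

from models_provider.impl.local_model_provider.model.reranker import LocalReranker

reranker = LocalReranker.new_instance(
    model_type='RERANKER',                        # ignored by the local implementation
    model_name='BAAI/bge-reranker-base',          # assumed cross-encoder checkpoint
    model_credential={'cache_dir': '/opt/maxkb-app/model/reranker'},
)

docs = [Document(page_content='MaxKB is a knowledge base platform.'),
        Document(page_content='Bananas are rich in potassium.')]

for doc in reranker.compress_documents(docs, query='What is MaxKB?'):
    print(round(doc.metadata['relevance_score'], 4), doc.page_content)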
/dev/null +++ b/apps/models_provider/impl/local_model_provider/model/reranker/web.py @@ -0,0 +1,52 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/5 15:30 + @desc: +""" +from typing import Sequence, Optional, Dict + +import requests +from anthropic import BaseModel +from langchain_core.callbacks import Callbacks +from langchain_core.documents import Document, BaseDocumentCompressor + +from maxkb.const import CONFIG +from models_provider.base_model_provider import MaxKBBaseModel + + +class LocalReranker(MaxKBBaseModel, BaseModel, BaseDocumentCompressor): + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalReranker(model_type=model_type, model_name=model_name, model_credential=model_credential, + **model_kwargs) + + model_id: str = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + print('ssss', kwargs.get('model_id', None)) + self.model_id = kwargs.get('model_id', None) + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + if documents is None or len(documents) == 0: + return [] + prefix = CONFIG.get_admin_path() + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/{self.model_id}/compress_documents', + json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in + documents], 'query': query}, headers={'Content-Type': 'application/json'}) + result = res.json() + if result.get('code', 500) == 200: + return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document + in result.get('data')] + raise Exception(result.get('message')) diff --git a/apps/models_provider/impl/ollama_model_provider/model/reranker.py b/apps/models_provider/impl/ollama_model_provider/model/reranker.py index 22d6e63d222..57435961be2 100644 --- a/apps/models_provider/impl/ollama_model_provider/model/reranker.py +++ b/apps/models_provider/impl/ollama_model_provider/model/reranker.py @@ -1,12 +1,12 @@ -from typing import Sequence, Optional, Any, Dict +from typing import Sequence, Optional, Dict from langchain_community.embeddings import OllamaEmbeddings from langchain_core.callbacks import Callbacks from langchain_core.documents import Document -from models_provider.base_model_provider import MaxKBBaseModel -from sklearn.metrics.pairwise import cosine_similarity from pydantic import BaseModel, Field +from models_provider.base_model_provider import MaxKBBaseModel + class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel): top_n: Optional[int] = Field(3, description="Number of top documents to return") @@ -22,6 +22,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ Sequence[Document]: + from sklearn.metrics.pairwise import cosine_similarity """Rank documents based on their similarity to the query. 
Args: @@ -37,7 +38,7 @@ def compress_documents(self, documents: Sequence[Document], query: str, callback document_embeddings = self.embed_documents(documents) # 计算相似度 similarities = cosine_similarity([query_embedding], document_embeddings)[0] - ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n] + ranked_docs = [(doc, _) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n] return [ Document( page_content=doc, # 第一个值是文档内容 @@ -45,5 +46,3 @@ def compress_documents(self, documents: Sequence[Document], query: str, callback ) for doc, score in ranked_docs ] - - diff --git a/apps/models_provider/impl/tencent_cloud_model_provider/model/llm.py b/apps/models_provider/impl/tencent_cloud_model_provider/model/llm.py index dc962e49119..9d097dcde4e 100644 --- a/apps/models_provider/impl/tencent_cloud_model_provider/model/llm.py +++ b/apps/models_provider/impl/tencent_cloud_model_provider/model/llm.py @@ -6,9 +6,7 @@ @date:2024/4/18 15:28 @desc: """ -from typing import List, Dict - -from langchain_core.messages import BaseMessage, get_buffer_string +from typing import Dict from common.config.tokenizer_manage_config import TokenizerManage from models_provider.base_model_provider import MaxKBBaseModel diff --git a/apps/models_provider/impl/vllm_model_provider/credential/whisper_stt.py b/apps/models_provider/impl/vllm_model_provider/credential/whisper_stt.py index f65b38eacb9..5844d0a4d4f 100644 --- a/apps/models_provider/impl/vllm_model_provider/credential/whisper_stt.py +++ b/apps/models_provider/impl/vllm_model_provider/credential/whisper_stt.py @@ -13,7 +13,7 @@ class VLLMWhisperModelParams(BaseForm): Language = forms.TextInputField( - TooltipLabel(_('Language'), + TooltipLabel(_('language'), _("If not passed, the default value is 'zh'")), required=True, default_value='zh', diff --git a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py index 922d934a8d8..ca502c4b0d4 100644 --- a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py +++ b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py @@ -52,11 +52,11 @@ def speech_to_text(self, audio_file): api_key=self.api_key, base_url=base_url ) - + buf = audio_file.read() filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} transcription_params = { 'model': self.model, - 'file': audio_file, + 'file': buf, 'language': 'zh', } result = client.audio.transcriptions.create( diff --git a/apps/ops/celery/signal_handler.py b/apps/ops/celery/signal_handler.py index 91cba3ed0db..d592803fc59 100644 --- a/apps/ops/celery/signal_handler.py +++ b/apps/ops/celery/signal_handler.py @@ -17,12 +17,28 @@ logger = logging.getLogger(__file__) safe_str = lambda x: x +def init_scheduler(): + from common import job + + job.run() + + try: + from xpack import job as xpack_job + + xpack_job.run() + except ImportError: + pass + + @worker_ready.connect def on_app_ready(sender=None, headers=None, **kwargs): if cache.get("CELERY_APP_READY", 0) == 1: return cache.set("CELERY_APP_READY", 1, 10) + # 初始化定时任务 + init_scheduler() + tasks = get_after_app_ready_tasks() logger.debug("Work ready signal recv") logger.debug("Start need start task: [{}]".format(", ".join(tasks))) diff --git a/apps/users/apps.py b/apps/users/apps.py index 72b1401065b..c115a830a28 100644 --- a/apps/users/apps.py +++ b/apps/users/apps.py @@ -4,3 +4,6 @@ class UsersConfig(AppConfig): 
default_auto_field = 'django.db.models.BigAutoField' name = 'users' + + def ready(self): + from ops.celery import signal_handler \ No newline at end of file diff --git a/installer/Dockerfile b/installer/Dockerfile index c5ba2e5e312..790673d4488 100644 --- a/installer/Dockerfile +++ b/installer/Dockerfile @@ -9,11 +9,12 @@ RUN cd ui && ls -la && if [ -d "dist" ]; then exit 0; fi && \ FROM ghcr.io/1panel-dev/maxkb-base:python3.11-pg17.6 AS stage-build COPY --chmod=700 . /opt/maxkb-app RUN apt-get update && \ - apt-get install -y --no-install-recommends gettext libexpat1-dev libffi-dev && \ + apt-get install -y --no-install-recommends gcc g++ gettext libexpat1-dev libffi-dev && \ apt-get clean all && \ rm -rf /var/lib/apt/lists/* WORKDIR /opt/maxkb-app -RUN rm -rf /opt/maxkb-app/ui && \ +RUN gcc -shared -fPIC -o ${MAXKB_SANDBOX_HOME}/sandbox.so /opt/maxkb-app/installer/sandbox.c -ldl && \ + rm -rf /opt/maxkb-app/ui && \ pip install uv --break-system-packages && \ python -m uv pip install -r pyproject.toml && \ find /opt/maxkb-app -depth \( -name ".git*" -o -name ".docker*" -o -name ".idea*" -o -name ".editorconfig*" -o -name ".prettierrc*" -o -name "README.md" -o -name "poetry.lock" -o -name "pyproject.toml" \) -exec rm -rf {} + && \ diff --git a/installer/Dockerfile-base b/installer/Dockerfile-base index 187980c0833..1768090629e 100644 --- a/installer/Dockerfile-base +++ b/installer/Dockerfile-base @@ -1,7 +1,7 @@ FROM python:3.11-slim-trixie AS python-stage RUN python3 -m venv /opt/py3 -FROM ghcr.io/1panel-dev/maxkb-vector-model:v2.0.2 AS vector-model +FROM ghcr.io/1panel-dev/maxkb-vector-model:v2.0.3 AS vector-model FROM postgres:17.6-trixie COPY --from=python-stage /usr/local /usr/local @@ -31,7 +31,6 @@ RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ chmod g+xr /usr/bin/ld.so && \ chmod g+x /usr/local/bin/python* && \ apt-get clean all && \ - echo "/usr/lib/$(uname -m)-linux-gnu/libjemalloc.so.2" > /etc/ld.so.preload && \ rm -rf /var/lib/apt/lists/* /usr/share/doc/* /usr/share/man/* /usr/share/info/* /usr/share/locale/* /usr/share/lintian/* /usr/share/linda/* /var/cache/* /var/log/* /var/tmp/* /tmp/* COPY --from=vector-model --chmod=700 /opt/maxkb-app/model /opt/maxkb-app/model @@ -46,8 +45,10 @@ ENV PATH=/opt/py3/bin:$PATH \ MAXKB_CONFIG_TYPE=ENV \ MAXKB_LOG_LEVEL=INFO \ MAXKB_SANDBOX=1 \ + MAXKB_SANDBOX_HOME=/opt/maxkb-app/sandbox \ MAXKB_SANDBOX_PYTHON_PACKAGE_PATHS="/opt/py3/lib/python3.11/site-packages,/opt/maxkb-app/sandbox/python-packages,/opt/maxkb/python-packages" \ MAXKB_SANDBOX_PYTHON_BANNED_KEYWORDS="subprocess.,system(,exec(,execve(,pty.,eval(,compile(,shutil.,input(,__import__" \ + MAXKB_SANDBOX_PYTHON_BANNED_HOSTS="127.0.0.1,localhost,host.docker.internal,maxkb,pgsql,redis" \ MAXKB_ADMIN_PATH=/admin EXPOSE 6379 \ No newline at end of file diff --git a/installer/Dockerfile-vector-model b/installer/Dockerfile-vector-model index c73e0307923..6001ace553e 100644 --- a/installer/Dockerfile-vector-model +++ b/installer/Dockerfile-vector-model @@ -25,7 +25,10 @@ COPY --from=vector-model /opt/maxkb/app/model /opt/maxkb-app/model COPY --from=vector-model /opt/maxkb/app/model/base/hub /opt/maxkb-app/model/tokenizer COPY --from=tmp-stage1 model/tokenizer /opt/maxkb-app/model/tokenizer RUN rm -rf /opt/maxkb-app/model/embedding/shibing624_text2vec-base-chinese/onnx - +RUN apk add --update --no-cache curl && \ + mkdir -p openai-tiktoken-cl100k-base && \ + curl -Lf https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken > 
openai-tiktoken-cl100k-base/cl100k_base.tiktoken && \ + mv -f openai-tiktoken-cl100k-base /opt/maxkb-app/model/tokenizer/ FROM scratch diff --git a/installer/sandbox.c b/installer/sandbox.c new file mode 100644 index 00000000000..9d3a7c928b5 --- /dev/null +++ b/installer/sandbox.c @@ -0,0 +1,144 @@ +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const char *BANNED_FILE_NAME = ".SANDBOX_BANNED_HOSTS"; + +/** + * 从 .so 文件所在目录读取 .SANDBOX_BANNED_HOSTS 文件内容 + * 返回 malloc 出的字符串(需 free),读取失败则返回空字符串 + */ +static char *load_banned_hosts() { + Dl_info info; + if (dladdr((void *)load_banned_hosts, &info) == 0 || !info.dli_fname) { + fprintf(stderr, "[sandbox] ⚠️ Unable to locate shared object path — allowing all hosts\n"); + return strdup(""); + } + + char so_path[PATH_MAX]; + strncpy(so_path, info.dli_fname, sizeof(so_path)); + so_path[sizeof(so_path) - 1] = '\0'; + + char *dir = dirname(so_path); + char file_path[PATH_MAX]; + snprintf(file_path, sizeof(file_path), "%s/%s", dir, BANNED_FILE_NAME); + + FILE *fp = fopen(file_path, "r"); + if (!fp) { + fprintf(stderr, "[sandbox] ⚠️ Cannot open %s — allowing all hosts\n", file_path); + return strdup(""); + } + + char *buf = malloc(4096); + if (!buf) { + fclose(fp); + fprintf(stderr, "[sandbox] ⚠️ Memory allocation failed — allowing all hosts\n"); + return strdup(""); + } + + size_t len = fread(buf, 1, 4095, fp); + buf[len] = '\0'; + fclose(fp); + return buf; +} + +/** + * 精确匹配黑名单 + */ +static int match_env_patterns(const char *target, const char *env_val) { + if (!target || !env_val || !*env_val) return 0; + + char *patterns = strdup(env_val); + char *token = strtok(patterns, ","); + int matched = 0; + + while (token) { + // 去掉前后空格 + while (*token == ' ' || *token == '\t') token++; + char *end = token + strlen(token) - 1; + while (end > token && (*end == ' ' || *end == '\t')) *end-- = '\0'; + + if (*token) { + regex_t regex; + char fullpattern[512]; + snprintf(fullpattern, sizeof(fullpattern), "^%s$", token); + + if (regcomp(®ex, fullpattern, REG_EXTENDED | REG_NOSUB | REG_ICASE) == 0) { + if (regexec(®ex, target, 0, NULL, 0) == 0) { + matched = 1; + regfree(®ex); + break; + } + regfree(®ex); + } else { + fprintf(stderr, "[sandbox] ⚠️ Invalid regex '%s' — allowing host by default\n", token); + } + } + token = strtok(NULL, ","); + } + + free(patterns); + return matched; +} + +/** 拦截 connect() —— 精确匹配 IP */ +int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { + static int (*real_connect)(int, const struct sockaddr *, socklen_t) = NULL; + if (!real_connect) + real_connect = dlsym(RTLD_NEXT, "connect"); + + static char *banned_env = NULL; + if (!banned_env) banned_env = load_banned_hosts(); + + char ip[INET6_ADDRSTRLEN] = {0}; + if (addr->sa_family == AF_INET) + inet_ntop(AF_INET, &((struct sockaddr_in *)addr)->sin_addr, ip, sizeof(ip)); + else if (addr->sa_family == AF_INET6) + inet_ntop(AF_INET6, &((struct sockaddr_in6 *)addr)->sin6_addr, ip, sizeof(ip)); + + if (banned_env && *banned_env && match_env_patterns(ip, banned_env)) { + fprintf(stderr, "[sandbox] 🚫 Access to host %s is banned\n", ip); + errno = EACCES; // EACCES 的值是 13, 意思是 Permission denied + return -1; + } + + return real_connect(sockfd, addr, addrlen); +} + +/** 拦截 getaddrinfo() —— 只拦截域名,不拦截纯 IP */ +int getaddrinfo(const char *node, const char *service, + const struct addrinfo *hints, struct addrinfo **res) { + static int (*real_getaddrinfo)(const char *, const char *, 
+ const struct addrinfo *, struct addrinfo **) = NULL; + if (!real_getaddrinfo) + real_getaddrinfo = dlsym(RTLD_NEXT, "getaddrinfo"); + + static char *banned_env = NULL; + if (!banned_env) banned_env = load_banned_hosts(); + + if (banned_env && *banned_env && node) { + // 检测 node 是否是 IP + struct in_addr ipv4; + struct in6_addr ipv6; + int is_ip = (inet_pton(AF_INET, node, &ipv4) == 1) || + (inet_pton(AF_INET6, node, &ipv6) == 1); + + // 只对“非IP的域名”进行屏蔽 + if (!is_ip && match_env_patterns(node, banned_env)) { + fprintf(stderr, "[sandbox] 🚫 Access to host %s is banned (DNS blocked)\n", node); + return EAI_FAIL; // 模拟 DNS 层禁止 + } + } + + return real_getaddrinfo(node, service, hints, res); +} diff --git a/main.py b/main.py index 0a300e9cb2d..f738764fbce 100644 --- a/main.py +++ b/main.py @@ -13,7 +13,6 @@ os.chdir(BASE_DIR) sys.path.insert(0, APP_DIR) os.environ.setdefault("DJANGO_SETTINGS_MODULE", "maxkb.settings") -django.setup() def collect_static(): @@ -53,7 +52,7 @@ def start_services(): if args.worker: start_args.extend(['--worker', str(args.worker)]) else: - worker = os.environ.get('CORE_WORKER') + worker = os.environ.get('MAXKB_CORE_WORKER') if isinstance(worker, str) and worker.isdigit(): start_args.extend(['--worker', worker]) @@ -74,7 +73,6 @@ def dev(): elif services.__contains__('celery'): management.call_command('celery', 'celery') elif services.__contains__('local_model'): - os.environ.setdefault('SERVER_NAME', 'local_model') from maxkb.const import CONFIG bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' management.call_command('runserver', bind) @@ -108,6 +106,12 @@ def dev(): parser.add_argument('-f', '--force', nargs="?", const=True) args = parser.parse_args() action = args.action + services = args.services if isinstance(args.services, list) else args.services + if services.__contains__('web'): + os.environ.setdefault('SERVER_NAME', 'web') + elif services.__contains__('local_model'): + os.environ.setdefault('SERVER_NAME', 'local_model') + django.setup() if action == "upgrade_db": perform_db_migrate() elif action == "collect_static": @@ -120,4 +124,3 @@ def dev(): collect_static() perform_db_migrate() start_services() - diff --git a/pyproject.toml b/pyproject.toml index 52e9d5fc846..9d4dd900f6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [{ name = "shaohuzhang1", email = "shaohu.zhang@fit2cloud.com" }] requires-python = "~=3.11.0" readme = "README.md" dependencies = [ - "django==5.2.7", + "django==5.2.8", "drf-spectacular[sidecar]==0.28.0", "django-redis==6.0.0", "django-db-connection-pool==1.2.6", @@ -29,6 +29,7 @@ dependencies = [ "langchain-huggingface==0.3.0", "langchain-ollama==0.3.4", "langgraph==0.5.3", + "langchain_core==0.3.74", "torch==2.8.0", "sentence-transformers==5.0.0", "qianfan==0.4.12.3", diff --git a/ui/package.json b/ui/package.json index 932f853f700..de0fa11ae9f 100644 --- a/ui/package.json +++ b/ui/package.json @@ -24,17 +24,16 @@ "@logicflow/extension": "^1.2.27", "@vavt/cm-extension": "^1.9.1", "@vueuse/core": "^13.3.0", - "@wecom/jssdk": "^2.3.1", "axios": "^1.8.4", "cropperjs": "^1.6.2", "dingtalk-jsapi": "^3.1.0", "echarts": "^5.6.0", - "element-plus": "^2.10.2", + "element-plus": "^2.11.7", "file-saver": "^2.0.5", "highlight.js": "^11.11.1", "html-to-image": "^1.11.13", "html2canvas": "^1.4.1", - "jspdf": "^3.0.1", + "jspdf": "^3.0.3", "katex": "^0.16.10", "marked": "^12.0.2", "md-editor-v3": "^5.8.2", diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 
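installer/sandbox.c above is compiled with gcc -shared -fPIC -o ${MAXKB_SANDBOX_HOME}/sandbox.so /opt/maxkb-app/installer/sandbox.c -ldl in the installer Dockerfile, and it hooks connect() and getaddrinfo() so that any IP or hostname matching an entry in the .SANDBOX_BANNED_HOSTS file next to the library is refused; how that file is populated and how the library is preloaded (presumably LD_PRELOAD in the sandbox runner) is outside this diff. Each comma-separated entry is trimmed and applied as an anchored, case-insensitive POSIX extended regex. The Python re-implementation of just that matching rule below is useful for predicting what the C hook will block; the function name is illustrative, and the example list mirrors the MAXKB_SANDBOX_PYTHON_BANNED_HOSTS default from Dockerfile-base.

# Python re-implementation of the ban-list matching semantics from
# installer/sandbox.c (match_env_patterns): each comma-separated entry is
# trimmed and compiled as an anchored, case-insensitive regex.
import re


def is_banned(target: str, banned_hosts: str) -> bool:
    if not target or not banned_hosts:
        return False
    for token in banned_hosts.split(','):
        token = token.strip()
        if not token:
            continue
        try:
            if re.fullmatch(token, target, flags=re.IGNORECASE):
                return True
        except re.error:
            # Invalid pattern: like the C hook, warn-and-allow rather than block.
            continue
    return False


if __name__ == '__main__':
    banned = '127.0.0.1,localhost,host.docker.internal,maxkb,pgsql,redis'
    assert is_banned('localhost', banned)
    assert is_banned('LOCALHOST', banned)       # case-insensitive
    assert is_banned('127.0.0.1', banned)
    assert not is_banned('127.0.0.10', banned)  # anchored match, no prefix hits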
66f4df0fc94..02a18e4ec46 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -583,16 +583,17 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_para if (props.chatId === 'new') { emit('refresh', chartOpenId.value) } - if (props.type === 'debug-ai-chat') { - getSourceDetail(chat) - } else { - if ( - props.applicationDetails && - (props.applicationDetails.show_exec || props.applicationDetails.show_source) - ) { - getSourceDetail(chat) - } - } + getSourceDetail(chat) + // if (props.type === 'debug-ai-chat') { + // getSourceDetail(chat) + // } else { + // if ( + // props.applicationDetails && + // (props.applicationDetails.show_exec || props.applicationDetails.show_source) + // ) { + // getSourceDetail(chat) + // } + // } }) .finally(() => { ChatManagement.close(chat.id) diff --git a/ui/src/components/app-charts/components/LineCharts.vue b/ui/src/components/app-charts/components/LineCharts.vue index 8e56944ffe5..d2cd2951e59 100644 --- a/ui/src/components/app-charts/components/LineCharts.vue +++ b/ui/src/components/app-charts/components/LineCharts.vue @@ -122,7 +122,6 @@ onMounted(() => { }) onBeforeUnmount(() => { - // echarts?.getInstanceByDom(document.getElementById(props.id)!)?.dispose() window.removeEventListener('resize', changeChartSize) }) diff --git a/ui/src/components/common-list/index.vue b/ui/src/components/common-list/index.vue index 124d81f8416..652061f4baa 100644 --- a/ui/src/components/common-list/index.vue +++ b/ui/src/components/common-list/index.vue @@ -79,7 +79,7 @@ defineExpose({ line-height: 24px; &.active { background: var(--el-color-primary-light-9); - border-radius: var(); + border-radius: var(--app-border-radius-small); color: var(--el-color-primary); font-weight: 500; &:hover { @@ -87,7 +87,7 @@ defineExpose({ } } &:hover { - border-radius: var(); + border-radius: var(--app-border-radius-small); background: var(--app-text-color-light-1); } &.is-active { diff --git a/ui/src/components/folder-breadcrumb/index.vue b/ui/src/components/folder-breadcrumb/index.vue index 8517c19e855..7dce4985b5e 100644 --- a/ui/src/components/folder-breadcrumb/index.vue +++ b/ui/src/components/folder-breadcrumb/index.vue @@ -3,14 +3,34 @@ {{ breadcrumbData[0]?.name }} - -
-        {{ item.name }}
-        {{ item.name }}
+
+
diff --git a/ui/src/components/folder-tree/MoveToDialog.vue b/ui/src/components/folder-tree/MoveToDialog.vue index 861e5089f6b..af61f6a3761 100644 --- a/ui/src/components/folder-tree/MoveToDialog.vue +++ b/ui/src/components/folder-tree/MoveToDialog.vue @@ -11,12 +11,6 @@ ref="treeRef" :source="source" :data="folderList" - :treeStyle="{ - height: 'calc(100vh - 320px)', - border: '1px solid #ebeef5', - borderRadius: '6px', - padding: '8px', - }" :default-expanded-keys="[currentNodeKey]" :canOperation="false" class="move-to-dialog-tree" @@ -36,6 +30,7 @@ @@ -47,17 +92,34 @@ const showBack = computed(() => { &__left { position: relative; box-sizing: border-box; - transition: width 0.28s; + // transition: width 0.28s; width: var(--sidebar-width); - min-width: var(--sidebar-width); - box-sizing: border-box; + .splitter-bar-line { + z-index: 1; + position: absolute; + top: 0; + right: 0; + cursor: col-resize; + width: 4px; + height: 100%; + &.hover:after { + width: 1px; + height: 100%; + content: ''; + z-index: 2; + position: absolute; + right: -1px; + top: 0; + background: var(--el-color-primary); + } + } .collapse { position: absolute; top: 36px; right: -12px; box-shadow: 0px 5px 10px 0px var(--app-text-color-light-1); - z-index: 1; + z-index: 2; } .layout-container__left_content { diff --git a/ui/src/layout/layout-header/avatar/index.vue b/ui/src/layout/layout-header/avatar/index.vue index a7358c262d4..7a71d08d3e7 100644 --- a/ui/src/layout/layout-header/avatar/index.vue +++ b/ui/src/layout/layout-header/avatar/index.vue @@ -4,7 +4,7 @@ - {{ user.userInfo?.username }} + {{ user.userInfo?.nick_name }} @@ -19,7 +19,7 @@
-            {{ user.userInfo?.username }}
+            {{ user.userInfo?.nick_name }}({{ user.userInfo?.username }})
@@ -198,15 +196,17 @@ function getList() { type: 'tool', isShared: folder_id === 'share', systemType: 'workspace', - }).getToolList({ - folder_id: folder_id, - tool_type: 'CUSTOM' - }).then((res: any) => { - toolList.value = res.data?.tools || res.data || [] - toolList.value = toolList.value?.filter((item: any) => item.is_active) - searchData.value = res.data.tools || res.data - searchData.value = searchData.value?.filter((item: any) => item.is_active) }) + .getToolList({ + folder_id: folder_id, + tool_type: 'CUSTOM', + }) + .then((res: any) => { + toolList.value = res.data?.tools || res.data || [] + toolList.value = toolList.value?.filter((item: any) => item.is_active) + searchData.value = res.data.tools || res.data + searchData.value = searchData.value?.filter((item: any) => item.is_active) + }) } defineExpose({ open }) diff --git a/ui/src/views/application/index.vue b/ui/src/views/application/index.vue index 34a0a197f5f..1e43f7db940 100644 --- a/ui/src/views/application/index.vue +++ b/ui/src/views/application/index.vue @@ -1,16 +1,16 @@