Skip to content

Commit e8f8009

Browse files
committed
refactor: model
1 parent d49f448 commit e8f8009

File tree

3 files changed

+20
-23
lines changed

3 files changed

+20
-23
lines changed

apps/models_provider/serializers/model_serializer.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
import json
3+
import os
34
import threading
45
import time
56
from typing import Dict
@@ -11,8 +12,11 @@
1112
from django.db.models.query_utils import Q
1213
from common.config.embedding_config import ModelManage
1314
from common.database_model_manage.database_model_manage import DatabaseModelManage
15+
from common.db.sql_execute import select_list
1416
from common.exception.app_exception import AppApiException
17+
from common.utils.common import get_file_content
1518
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
19+
from maxkb.conf import PROJECT_DIR
1620
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
1721
from models_provider.constants.model_provider_constants import ModelProvideConstants
1822
from models_provider.models import Model, Status
@@ -412,27 +416,13 @@ def save_model_params_form(self, model_params_form, with_valid=True):
412416
return True
413417

414418

415-
def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization):
416-
# 对所有工作空间拉黑的工具
417-
non_auths = QuerySet(model_workspace_authorization).filter(
418-
Q(workspace_id='None') & Q(authentication_type='WHITE_LIST')
419-
).values_list('model_id', flat=True)
420-
# 授权给所有工作空间的工具
421-
all_auths = QuerySet(model_workspace_authorization).filter(
422-
Q(workspace_id='None') & Q(authentication_type='BLACK_LIST')
423-
).values_list('model_id', flat=True)
424-
# 查询白名单授权的工具
425-
white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
426-
workspace_id=workspace_id, authentication_type='WHITE_LIST'
427-
).values_list('model_id', flat=True)
428-
# 查询黑名单授权的工具
429-
black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
430-
workspace_id=workspace_id, authentication_type='BLACK_LIST'
431-
).values_list('model_id', flat=True)
419+
def get_authorized_tool(tool_query_set, workspace_id):
420+
model_id_list = select_list(get_file_content(
421+
os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
422+
'list_share_authorized_model.sql'
423+
)), [workspace_id, workspace_id])
432424
tool_query_set = tool_query_set.filter(
433-
id__in=list(white_authorized_tool_ids) + list(all_auths)
434-
).exclude(
435-
id__in=list(black_authorized_tool_ids) + list(non_auths)
425+
id__in=[k.get('model_id') for k in model_id_list]
436426
)
437427
return tool_query_set
438428

@@ -471,8 +461,7 @@ def _build_queryset(self, workspace_id):
471461
if workspace_id:
472462
model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization")
473463
if model_workspace_authorization is not None:
474-
queryset = get_authorized_tool(queryset, workspace_id,
475-
model_workspace_authorization=model_workspace_authorization)
464+
queryset = get_authorized_tool(queryset, workspace_id)
476465

477466
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
478467
value = self.data.get(field)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
select model_id
2+
from model_workspace_authorization
3+
where case
4+
when authentication_type = 'WHITE_LIST' then
5+
%s = any (workspace_id_list)
6+
else
7+
not %s = any(workspace_id_list)
8+
end

ui/src/locales/lang/zh-CN/views/model.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export default {
99
},
1010
tip: {
1111
createSuccessMessage: '创建模型成功',
12-
createErrorMessage: '基础信息有填写错误',
12+
createErrorMessage: '基础信息填写错误',
1313
errorMessage: '变量已存在: ',
1414
emptyMessage1: '请先选择基础信息的模型类型和基础模型',
1515
emptyMessage2: '所选模型不支持参数设置',

0 commit comments

Comments
 (0)