|
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 | import json |
| 3 | +import os |
3 | 4 | import threading |
4 | 5 | import time |
5 | 6 | from typing import Dict |
|
11 | 12 | from django.db.models.query_utils import Q |
12 | 13 | from common.config.embedding_config import ModelManage |
13 | 14 | from common.database_model_manage.database_model_manage import DatabaseModelManage |
| 15 | +from common.db.sql_execute import select_list |
14 | 16 | from common.exception.app_exception import AppApiException |
| 17 | +from common.utils.common import get_file_content |
15 | 18 | from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt |
| 19 | +from maxkb.conf import PROJECT_DIR |
16 | 20 | from models_provider.base_model_provider import ValidCode, DownModelChunkStatus |
17 | 21 | from models_provider.constants.model_provider_constants import ModelProvideConstants |
18 | 22 | from models_provider.models import Model, Status |
@@ -412,27 +416,13 @@ def save_model_params_form(self, model_params_form, with_valid=True): |
412 | 416 | return True |
413 | 417 |
|
414 | 418 |
|
415 | | -def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization): |
416 | | - # 对所有工作空间拉黑的工具 |
417 | | - non_auths = QuerySet(model_workspace_authorization).filter( |
418 | | - Q(workspace_id='None') & Q(authentication_type='WHITE_LIST') |
419 | | - ).values_list('model_id', flat=True) |
420 | | - # 授权给所有工作空间的工具 |
421 | | - all_auths = QuerySet(model_workspace_authorization).filter( |
422 | | - Q(workspace_id='None') & Q(authentication_type='BLACK_LIST') |
423 | | - ).values_list('model_id', flat=True) |
424 | | - # 查询白名单授权的工具 |
425 | | - white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter( |
426 | | - workspace_id=workspace_id, authentication_type='WHITE_LIST' |
427 | | - ).values_list('model_id', flat=True) |
428 | | - # 查询黑名单授权的工具 |
429 | | - black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter( |
430 | | - workspace_id=workspace_id, authentication_type='BLACK_LIST' |
431 | | - ).values_list('model_id', flat=True) |
| 419 | +def get_authorized_tool(tool_query_set, workspace_id): |
| 420 | + model_id_list = select_list(get_file_content( |
| 421 | + os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql', |
| 422 | + 'list_share_authorized_model.sql' |
| 423 | + )), [workspace_id, workspace_id]) |
432 | 424 | tool_query_set = tool_query_set.filter( |
433 | | - id__in=list(white_authorized_tool_ids) + list(all_auths) |
434 | | - ).exclude( |
435 | | - id__in=list(black_authorized_tool_ids) + list(non_auths) |
| 425 | + id__in=[k.get('model_id') for k in model_id_list] |
436 | 426 | ) |
437 | 427 | return tool_query_set |
438 | 428 |
|
@@ -471,8 +461,7 @@ def _build_queryset(self, workspace_id): |
471 | 461 | if workspace_id: |
472 | 462 | model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization") |
473 | 463 | if model_workspace_authorization is not None: |
474 | | - queryset = get_authorized_tool(queryset, workspace_id, |
475 | | - model_workspace_authorization=model_workspace_authorization) |
| 464 | + queryset = get_authorized_tool(queryset, workspace_id) |
476 | 465 |
|
477 | 466 | for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: |
478 | 467 | value = self.data.get(field) |
|
0 commit comments