Skip to content

Commit 3450ef7

Browse files
authored
fix: Application import and export (#3538)
1 parent 857f988 commit 3450ef7

File tree

3 files changed

+56
-13
lines changed

3 files changed

+56
-13
lines changed

apps/application/serializers/application.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414
import pickle
1515
import re
16+
from functools import reduce
1617
from typing import Dict, List
1718

1819
import uuid_utils.compat as uuid
@@ -33,7 +34,7 @@
3334
from common.db.search import native_search, native_page_search
3435
from common.exception.app_exception import AppApiException
3536
from common.field.common import UploadedFileField
36-
from common.utils.common import get_file_content, valid_license, restricted_loads
37+
from common.utils.common import get_file_content, valid_license, restricted_loads, generate_uuid
3738
from knowledge.models import Knowledge, KnowledgeScope
3839
from knowledge.serializers.knowledge import KnowledgeSerializer, KnowledgeModelSerializer
3940
from maxkb.conf import PROJECT_DIR
@@ -493,18 +494,37 @@ def import_(self, instance: dict, with_valid=True):
493494
except Exception as e:
494495
raise AppApiException(1001, _("Unsupported file format"))
495496
application = mk_instance.application
496-
497497
tool_list = mk_instance.get_tool_list()
498+
update_tool_map = {}
498499
if len(tool_list) > 0:
499-
tool_id_list = [tool.get('id') for tool in tool_list]
500+
tool_id_list = reduce(lambda x, y: [*x, *y],
501+
[[tool.get('id'), generate_uuid((tool.get('id') + tool.get('workspace_id') or ''))]
502+
for tool
503+
in
504+
tool_list], [])
505+
# 存在的工具列表
500506
exits_tool_id_list = [str(tool.id) for tool in
501-
QuerySet(Tool).filter(id__in=tool_id_list)]
502-
# 获取到需要插入的函数
503-
tool_list = [tool for tool in tool_id_list if
504-
not exits_tool_id_list.__contains__(tool.get('id'))]
505-
application_model = self.to_application(application, workspace_id, user_id)
507+
QuerySet(Tool).filter(id__in=tool_id_list, workspace_id=workspace_id)]
508+
# 需要更新的工具集合
509+
update_tool_map = {tool.get('id'): generate_uuid((tool.get('id') + tool.get('workspace_id') or '')) for tool
510+
in
511+
tool_list if
512+
not exits_tool_id_list.__contains__(
513+
tool.get('id'))}
514+
515+
tool_list = [{**tool, 'id': update_tool_map.get(tool.get('id'))} for tool in tool_list if
516+
not exits_tool_id_list.__contains__(
517+
tool.get('id')) and not exits_tool_id_list.__contains__(
518+
generate_uuid((tool.get('id') + tool.get('workspace_id') or '')))]
519+
application_model = self.to_application(application, workspace_id, user_id, update_tool_map)
506520
tool_model_list = [self.to_tool(f, workspace_id, user_id) for f in tool_list]
507521
application_model.save()
522+
# 插入授权数据
523+
UserResourcePermissionSerializer(data={
524+
'workspace_id': self.data.get('workspace_id'),
525+
'user_id': self.data.get('user_id'),
526+
'auth_target_type': AuthTargetType.APPLICATION.value
527+
}).auth_resource(str(application_model.id))
508528
# 插入认证信息
509529
ApplicationAccessToken(application_id=application_model.id,
510530
access_token=hashlib.md5(str(uuid.uuid7()).encode()).hexdigest()[8:24]).save()
@@ -526,18 +546,24 @@ def to_tool(tool, workspace_id, user_id):
526546
input_field_list=tool.get('input_field_list'),
527547
is_active=tool.get('is_active'),
528548
scope=ToolScope.WORKSPACE,
549+
folder_id=workspace_id,
529550
workspace_id=workspace_id)
530551

531552
@staticmethod
532-
def to_application(application, workspace_id, user_id):
553+
def to_application(application, workspace_id, user_id, update_tool_map):
533554
work_flow = application.get('work_flow')
534555
for node in work_flow.get('nodes', []):
556+
if node.get('type') == 'tool-lib-node':
557+
tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
558+
node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id,
559+
tool_lib_id)
535560
if node.get('type') == 'search-knowledge-node':
536561
node.get('properties', {}).get('node_data', {})['knowledge_id_list'] = []
537562
return Application(id=uuid.uuid7(),
538563
user_id=user_id,
539564
name=application.get('name'),
540565
workspace_id=workspace_id,
566+
folder_id=workspace_id,
541567
desc=application.get('desc'),
542568
prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number'),
543569
knowledge_setting=application.get('knowledge_setting'),
@@ -624,13 +650,13 @@ def export(self, with_valid=True):
624650
self.is_valid()
625651
application_id = self.data.get('application_id')
626652
application = QuerySet(Application).filter(id=application_id).first()
627-
tool_id_list = [node.get('properties', {}).get('node_data', {}).get('tool_id') for node
653+
tool_id_list = [node.get('properties', {}).get('node_data', {}).get('tool_lib_id') for node
628654
in
629655
application.work_flow.get('nodes', []) if
630-
node.get('type') == 'tool-node']
656+
node.get('type') == 'tool-lib-node']
631657
tool_list = []
632658
if len(tool_id_list) > 0:
633-
tool_list = QuerySet(Tool).filter(id__in=tool_id_list)
659+
tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
634660
application_dict = ApplicationSerializerModel(application).data
635661

636662
mk_instance = MKInstance(application_dict,

apps/common/utils/common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import random
1414
import re
1515
import shutil
16+
import uuid
1617
from functools import reduce
1718
from typing import List, Dict
1819

@@ -329,3 +330,7 @@ def parse_image(content: str):
329330
matches = re.finditer("!\[.*?\]\(\/oss\/(image|file)\/.*?\)", content)
330331
image_list = [match.group() for match in matches]
331332
return image_list
333+
334+
335+
def generate_uuid(tag: str):
336+
return str(uuid.uuid5(uuid.NAMESPACE_DNS, tag))

apps/tools/serializers/tool.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,18 @@ class Operate(serializers.Serializer):
287287
id = serializers.UUIDField(required=True, label=_('tool id'))
288288
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
289289

290+
def is_one_valid(self, *, raise_exception=False):
291+
super().is_valid(raise_exception=True)
292+
workspace_id = self.data.get('workspace_id')
293+
query_set = QuerySet(Tool).filter(id=self.data.get('id'))
294+
if workspace_id:
295+
query_set = query_set.filter(workspace_id=workspace_id)
296+
if not query_set.exists():
297+
get_authorized_tool = DatabaseModelManage.get_model('get_authorized_tool')
298+
if get_authorized_tool:
299+
return get_authorized_tool(QuerySet(Tool).filter(id=self.data.get('id')), workspace_id).exists()
300+
raise AppApiException(500, _('Tool id does not exist'))
301+
290302
def is_valid(self, *, raise_exception=False):
291303
super().is_valid(raise_exception=True)
292304
workspace_id = self.data.get('workspace_id')
@@ -337,7 +349,7 @@ def delete(self):
337349
QuerySet(Tool).filter(id=self.data.get('id')).delete()
338350

339351
def one(self):
340-
self.is_valid(raise_exception=True)
352+
self.is_one_valid(raise_exception=True)
341353
tool = QuerySet(Tool).filter(id=self.data.get('id')).first()
342354
if tool.init_params:
343355
tool.init_params = json.loads(rsa_long_decrypt(tool.init_params))

0 commit comments

Comments
 (0)