Skip to content

Commit 79fbc52

Browse files
committed
feat: add search document node functionality and related configurations
1 parent ff763c4 commit 79fbc52

File tree

17 files changed

+734
-8
lines changed

17 files changed

+734
-8
lines changed

apps/application/flow/step_node/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .mcp_node import BaseMcpNode
2424
from .question_node import *
2525
from .reranker_node import *
26+
from .search_document_node import BaseSearchDocumentNode
2627
from .search_knowledge_node import *
2728
from .speech_to_text_step_node import BaseSpeechToTextNode
2829
from .start_node import *
@@ -34,7 +35,7 @@
3435
from .variable_splitting_node import BaseVariableSplittingNode
3536
from .video_understand_step_node import BaseVideoUnderstandNode
3637

37-
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseQuestionNode,
38+
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseSearchDocumentNode, BaseQuestionNode,
3839
BaseConditionNode, BaseReplyNode,
3940
BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
4041
BaseDocumentExtractNode,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .impl import *
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# coding=utf-8
2+
from typing import Type, List
3+
4+
from django.utils.translation import gettext_lazy as _
5+
from rest_framework import serializers
6+
7+
from application.flow.i_step_node import INode, NodeResult
8+
9+
10+
class SearchDocumentStepNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on a search-document workflow node.

    Every field is optional; a missing field falls back to the default
    declared on it.
    """
    # UUIDs of the knowledge bases selected on the node.
    knowledge_id_list = serializers.ListField(
        required=False, child=serializers.UUIDField(required=True),
        label=_("knowledge id list"), default=list
    )
    # 'auto' or 'custom'.
    search_mode = serializers.ChoiceField(
        required=False, choices=['auto', 'custom'], label=_("search mode"), default='auto'
    )
    # 'custom' or 'referencing'; null is also accepted.
    search_scope_type = serializers.ChoiceField(
        required=False, choices=['custom', 'referencing'], label=_("search scope type"),
        allow_null=True, default='custom'
    )
    # Kind of value the scope reference points at: 'document' or 'knowledge'.
    search_scope_source = serializers.ChoiceField(
        required=False, choices=['document', 'knowledge'],
        label=_("search scope variable type"), default='knowledge'
    )
    # Reference path of the variable holding the search scope.
    search_scope_reference = serializers.ListField(
        required=False, label=_("search scope variable"), default=list
    )
    # Reference path of the variable holding the question.
    question_reference = serializers.ListField(
        required=False, label=_("question reference address"), default=list
    )
    # How multiple search conditions are combined: 'AND' or 'OR'.
    search_condition_type = serializers.ChoiceField(
        required=False, choices=['AND', 'OR'], label=_("search condition type"), default='AND'
    )
    # Search conditions; the items are not validated by this serializer.
    search_condition_list = serializers.ListField(
        required=False, label=_("search condition list"), default=list
    )

    def is_valid(self, *, raise_exception=False):
        """Validate the node parameters, always raising on invalid input.

        ``raise_exception`` is accepted for signature compatibility only;
        validation errors are raised regardless of its value.

        Returns:
            The result of the parent ``is_valid`` call, so callers that test
            the return value see a boolean instead of ``None``.
        """
        return super().is_valid(raise_exception=True)
41+
42+
43+
class ISearchDocumentStepNode(INode):
    """Interface of the search-document workflow node.

    Declares the node type, the serializer that validates its parameters and
    the ``execute`` hook that concrete implementations override.
    """
    type = 'search-document-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this node's parameters."""
        return SearchDocumentStepNodeSerializer

    def _run(self):
        """Call ``execute`` with the node and flow serializer data as keyword arguments."""
        node_params = self.node_params_serializer.data
        flow_params = self.flow_params_serializer.data
        return self.execute(**node_params, **flow_params)

    def execute(
            self,
            knowledge_id_list: List,
            search_mode: str,
            search_scope_type: str,
            search_scope_source: str,
            search_scope_reference: List,
            question_reference: List,
            search_condition_type: str,
            search_condition_list: List,
            **kwargs
    ) -> NodeResult:
        """Run the node; this base implementation does nothing and returns ``None``."""
        return None
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .base_search_document_node import BaseSearchDocumentNode
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# coding=utf-8
2+
from typing import List
3+
4+
import jieba
5+
from django.db.models import Q
6+
from django.db.models import QuerySet
7+
8+
from application.flow.i_step_node import NodeResult
9+
from application.flow.step_node.search_document_node.i_search_document_node import ISearchDocumentStepNode
10+
from knowledge.models import Document, DocumentTag, Knowledge
11+
12+
13+
class BaseSearchDocumentNode(ISearchDocumentStepNode):
    """Workflow node that selects documents by matching their tags.

    The candidate documents come either from the knowledge bases selected on
    the node or from a referenced variable.  They are then narrowed down
    either by keywords segmented from a question ('auto' mode) or by explicit
    tag conditions ('custom' mode).
    """

    # Context keys that are restored from saved details.
    _CONTEXT_KEYS = (
        'document_list', 'knowledge_list', 'document_items', 'knowledge_items', 'question', 'run_time',
    )

    # Tag-value lookup behind each supported ``compare`` operator.
    # 'not_contain' uses the same lookup as 'contain'; the matches are then
    # removed from the candidates instead of being kept.
    _TAG_VALUE_LOOKUPS = {
        'contain': 'tag__value__icontains',
        'eq': 'tag__value',
        'not_contain': 'tag__value__icontains',
    }

    def save_context(self, details, workflow_manage):
        """Restore this node's context entries from previously saved ``details``."""
        for key in self._CONTEXT_KEYS:
            self.context[key] = details.get(key)

    def get_reference_content(self, fields: List[str]):
        """Resolve a reference through the workflow manager.

        ``fields[0]`` and ``fields[1:]`` are passed to
        ``workflow_manage.get_reference_field``.

        Returns:
            The referenced value, or ``None`` when ``fields`` is empty (the
            serializer default), instead of failing with ``IndexError``.
        """
        if not fields:
            return None
        return self.workflow_manage.get_reference_field(fields[0], fields[1:])

    def execute(self, knowledge_id_list: List, search_mode: str, search_scope_type: str, search_scope_source: str,
                search_scope_reference: List, question_reference: List, search_condition_type: str,
                search_condition_list: List,
                **kwargs) -> NodeResult:
        """Search the documents in scope and return the matches.

        Returns:
            A ``NodeResult`` whose node variables hold the matched document
            ids and rows, the ids and rows of the knowledge bases those
            documents belong to, and the question used in 'auto' mode
            (``None`` in 'custom' mode).
        """
        document_id_list = self._resolve_scope(
            knowledge_id_list, search_scope_type, search_scope_source, search_scope_reference
        )

        question = None
        if search_mode == 'auto':
            # Search automatically with keywords taken from the question.
            question = self.get_reference_content(question_reference)
            matched_ids = self._match_question_keywords(document_id_list, question)
        else:
            # Search with the user-defined tag conditions.
            matched_ids = self.handle_custom_tags(
                document_id_list, search_condition_list, search_condition_type
            )

        # UUID to str
        final_document_ids = [str(doc_id) for doc_id in matched_ids]
        document_items = list(QuerySet(Document).filter(id__in=final_document_ids).values())
        final_knowledge_ids = list({str(doc['knowledge_id']) for doc in document_items})
        knowledge_items = list(QuerySet(Knowledge).filter(id__in=final_knowledge_ids).values())

        return NodeResult({
            'document_list': final_document_ids,
            'document_items': document_items,
            'knowledge_list': final_knowledge_ids,
            'knowledge_items': knowledge_items,
            # save_context/get_details read 'question', so it has to be part of the result.
            'question': question,
        }, {})

    def _resolve_scope(self, knowledge_id_list, search_scope_type, search_scope_source, search_scope_reference):
        """Return the ids of the documents the search is restricted to."""
        if search_scope_type == 'custom':
            # Knowledge bases selected manually on the node.
            return QuerySet(Document).filter(
                knowledge_id__in=knowledge_id_list
            ).values_list('id', flat=True)

        # Scope taken from a referenced variable; a missing value means an empty scope.
        referenced = self.get_reference_content(search_scope_reference) or []
        if search_scope_source == 'document':
            return referenced
        return QuerySet(Document).filter(
            knowledge_id__in=referenced
        ).values_list('id', flat=True)

    def handle_auto_tags(self, document_id_list: list, question_reference: list):
        """Return the ids of in-scope documents whose tag values contain a keyword of the referenced question."""
        question = self.get_reference_content(question_reference)
        return self._match_question_keywords(document_id_list, question)

    def _match_question_keywords(self, document_id_list, question):
        """Segment ``question`` with jieba and match the keywords against tag values."""
        if not question:
            return set()

        # Whitespace-only tokens are dropped: as an ``icontains`` filter they
        # would match almost every tag value.
        keywords = {token.strip() for token in jieba.lcut(str(question))}
        keywords.discard('')
        if not keywords:
            return set()

        # OR all keywords together so a single query fetches every match.
        keyword_filter = Q()
        for keyword in sorted(keywords):
            keyword_filter |= Q(tag__value__icontains=keyword)

        return set(
            QuerySet(DocumentTag)
            .filter(document_id__in=document_id_list)
            .filter(keyword_filter)
            .values_list('document_id', flat=True)
            .distinct()
        )

    def handle_custom_tags(self, document_id_list: List, search_condition_list: list, search_condition_type: str):
        """Filter the documents in scope by user-defined tag conditions.

        Each condition is a dict with ``key``, ``value`` and ``compare``
        ('contain', 'eq' or 'not_contain'); conditions with any other
        ``compare`` are ignored.  'AND' keeps the documents satisfying every
        condition, any other ``search_condition_type`` keeps those satisfying
        at least one.  When no condition applies, all documents in scope are
        returned.

        Returns:
            A set of document ids as strings.
        """
        # Ids are compared as strings throughout: a referenced scope may hold
        # strings while the database returns UUID objects, and mixing the two
        # would make every set operation come out empty.
        candidate_ids = {str(doc_id) for doc_id in document_id_list}
        if not search_condition_list:
            return candidate_ids

        if search_condition_type == 'AND':
            matched = candidate_ids
            for condition in search_condition_list:
                result = self._match_condition(matched, condition)
                if result is not None:
                    # Each condition only needs to look at the survivors so far.
                    matched = result
            return matched

        matched = set()
        applied = False
        for condition in search_condition_list:
            result = self._match_condition(candidate_ids, condition)
            if result is not None:
                applied = True
                matched |= result
        return matched if applied else candidate_ids

    def _match_condition(self, candidate_ids: set, condition: dict):
        """Return the candidates satisfying one condition, or ``None`` for an unsupported operator."""
        tag_key = condition['key']
        field_value = self.workflow_manage.generate_prompt(condition['value'])
        compare_type = condition['compare']

        lookup = self._TAG_VALUE_LOOKUPS.get(compare_type)
        if lookup is None:
            return None

        tagged = {
            str(doc_id) for doc_id in QuerySet(DocumentTag).filter(
                document_id__in=list(candidate_ids)
            ).filter(
                **{'tag__key': tag_key, lookup: field_value}
            ).values_list('document_id', flat=True).distinct()
        }
        if compare_type == 'not_contain':
            # Negating the filter row by row would accept a document as soon
            # as it has any other tag; remove the documents that do carry a
            # matching tag instead.
            return candidate_ids - tagged
        return tagged

    def get_details(self, index: int, **kwargs):
        """Return the execution details of this node for the run record."""
        return {
            'name': self.node.properties.get('stepName'),
            'question': self.context.get('question'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'document_list': self.context.get('document_list'),
            'knowledge_list': self.context.get('knowledge_list'),
            'document_items': self.context.get('document_items'),
            'knowledge_items': self.context.get('knowledge_items'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message
        }

apps/knowledge/serializers/knowledge.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import re
55
import traceback
6+
from collections import defaultdict
67
from functools import reduce
78
from tempfile import TemporaryDirectory
89
from typing import Dict, List
@@ -13,6 +14,7 @@
1314
from django.db import transaction, models
1415
from django.db.models import QuerySet
1516
from django.db.models.functions import Reverse, Substr
17+
from django.db.models.query_utils import Q
1618
from django.http import HttpResponse
1719
from django.utils.translation import gettext_lazy as _
1820
from rest_framework import serializers
@@ -29,8 +31,9 @@
2931
from common.utils.logger import maxkb_logger
3032
from common.utils.split_model import get_split_model
3133
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \
32-
ProblemParagraphMapping, TaskType, State, SearchMode, KnowledgeFolder, File
33-
from knowledge.serializers.common import ProblemParagraphManage, drop_knowledge_index, get_embedding_model_id_by_knowledge_id, MetaSerializer, \
34+
ProblemParagraphMapping, TaskType, State, SearchMode, KnowledgeFolder, File, Tag
35+
from knowledge.serializers.common import ProblemParagraphManage, drop_knowledge_index, \
36+
get_embedding_model_id_by_knowledge_id, MetaSerializer, \
3437
GenerateRelatedSerializer, get_embedding_model_by_knowledge_id, list_paragraph, write_image, zip_dir
3538
from knowledge.serializers.document import DocumentSerializers
3639
from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by_knowledge
@@ -148,7 +151,8 @@ def get_query_set(self, workspace_manage, is_x_pack_ee):
148151
if "workspace_id" in self.data and self.data.get('workspace_id') is not None:
149152
query_set = query_set.filter(**{'temp.workspace_id': self.data.get("workspace_id")})
150153
folder_query_set = folder_query_set.filter(**{'workspace_id': self.data.get("workspace_id")})
151-
if "folder_id" in self.data and self.data.get('folder_id') is not None and self.data.get('workspace_id') != self.data.get('folder_id'):
154+
if "folder_id" in self.data and self.data.get('folder_id') is not None and self.data.get(
155+
'workspace_id') != self.data.get('folder_id'):
152156
query_set = query_set.filter(**{'temp.folder_id': self.data.get("folder_id")})
153157
folder_query_set = folder_query_set.filter(**{'parent_id': self.data.get("folder_id")})
154158
if "scope" in self.data and self.data.get('scope') is not None:
@@ -764,3 +768,50 @@ def hit_test(self):
764768
'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')
765769
} for p in p_list
766770
]
771+
772+
class Tags(serializers.Serializer):
    """List the tags of several knowledge bases, grouped by tag key."""
    # NOTE(review): workspace_id and user_id are validated but not used to
    # restrict the query below — confirm that access to ``knowledge_ids`` is
    # checked elsewhere.
    workspace_id = serializers.CharField(required=True, label=_('workspace id'))
    user_id = serializers.UUIDField(required=True, label=_('user id'))
    knowledge_ids = serializers.ListField(
        required=True, label=_('knowledge ids'),
        child=serializers.UUIDField(required=True, label=_('id'))
    )
    # Optional text matched against tag keys and values.  The field must be
    # declared, otherwise ``self.data`` never contains ``name`` and the
    # filter in ``list`` can never apply.
    name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('name'))

    def list(self):
        """Return the tags of ``knowledge_ids`` grouped by key.

        When ``name`` is given, only tags whose key or value contains it
        (case-insensitive) are returned.

        Returns:
            A list of ``{'key': ..., 'values': [...]}`` dicts sorted by key;
            each value holds ``id``, ``value``, ``create_time`` and
            ``update_time`` and the values are ordered by creation time.

        Raises:
            The validation error raised by ``is_valid`` when the input is invalid.
        """
        self.is_valid(raise_exception=True)

        tag_query = QuerySet(Tag).filter(knowledge_id__in=self.data.get('knowledge_ids'))
        name = self.data.get('name')
        if name:
            tag_query = tag_query.filter(Q(key__icontains=name) | Q(value__icontains=name))
        # Ordering by creation time first keeps each group below in creation
        # order, so no per-group sort is needed afterwards.
        tags = tag_query.values(
            'key', 'value', 'id', 'create_time', 'update_time'
        ).order_by('create_time', 'key', 'value')

        # Group by key.
        grouped_tags = defaultdict(list)
        for tag in tags:
            grouped_tags[tag['key']].append({
                'id': tag['id'],
                'value': tag['value'],
                'create_time': tag['create_time'],
                'update_time': tag['update_time']
            })

        # Sort by key so the result order is stable.
        return [{'key': key, 'values': grouped_tags[key]} for key in sorted(grouped_tags)]

apps/knowledge/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
path('workspace/<str:workspace_id>/knowledge/web', views.KnowledgeWebView.as_view()),
1313
path('workspace/<str:workspace_id>/knowledge/model', views.KnowledgeView.Model.as_view()),
1414
path('workspace/<str:workspace_id>/knowledge/embedding_model', views.KnowledgeView.EmbeddingModel.as_view()),
15+
path('workspace/<str:workspace_id>/knowledge/tags', views.KnowledgeView.Tags.as_view()),
1516
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
1617
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/sync', views.KnowledgeView.SyncWeb.as_view()),
1718
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/generate_related', views.KnowledgeView.GenerateRelated.as_view()),

apps/knowledge/views/knowledge.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,29 @@ def get(self, request: Request, workspace_id: str):
382382
}
383383
).list(workspace_id, True))
384384

385+
class Tags(APIView):
    """Endpoint returning the tags of the requested knowledge bases."""
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['GET'],
        description=_('Get all tags of knowledge base'),
        summary=_('Get all tags of knowledge base'),
        operation_id=_('Get all tags of knowledge base'),  # type: ignore
        parameters=KnowledgeReadAPI.get_parameters(),
        responses=KnowledgeReadAPI.get_response(),
        tags=[_('Knowledge Base')]  # type: ignore
    )
    @has_permissions(
        PermissionConstants.KNOWLEDGE_READ.get_workspace_permission(),
        RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()
    )
    def get(self, request: Request, workspace_id: str):
        """Return the tags of the knowledge bases named by the ``knowledge_ids[]`` query parameters."""
        serializer = KnowledgeSerializer.Tags(data={
            'user_id': request.user.id,
            'workspace_id': workspace_id,
            'knowledge_ids': request.query_params.getlist('knowledge_ids[]'),
        })
        return result.success(serializer.list())
407+
385408

386409
class KnowledgeBaseView(APIView):
387410
authentication_classes = [TokenAuth]

ui/src/api/knowledge/knowledge.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,12 @@ const putLarkKnowledge: (
255255
return put(`${prefix.value}/lark/${knowledge_id}`, data, undefined, loading)
256256
}
257257

258+
/**
 * Fetch tags from the `/tags` endpoint under the current prefix.
 * @param params query parameters forwarded to the request
 * @param loading optional ref passed through to `get`
 */
const getAllTags: (params: any, loading?: Ref<boolean>) => Promise<Result<any>> = (
  params,
  loading,
) => get(`${prefix.value}/tags`, params, loading)
258264

259265
const getTags: (knowledge_id: string, params: any, loading?: Ref<boolean>) => Promise<Result<any>> = (
260266
knowledge_id,
@@ -315,6 +321,7 @@ export default {
315321
postWebKnowledge,
316322
postLarkKnowledge,
317323
putLarkKnowledge,
324+
getAllTags,
318325
getTags,
319326
postTags,
320327
putTag,

0 commit comments

Comments
 (0)