Skip to content

Commit 59eaacf

Browse files
authored
Merge branch 'v2' into pr@v2@feat_parameter_node
2 parents ce3a870 + a4330e3 commit 59eaacf

File tree

53 files changed

+1418
-206
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1418
-206
lines changed

apps/application/flow/step_node/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .parameter_extraction_node import BaseParameterExtractionNode
2525
from .question_node import *
2626
from .reranker_node import *
27+
from .search_document_node import BaseSearchDocumentNode
2728
from .search_knowledge_node import *
2829
from .speech_to_text_step_node import BaseSpeechToTextNode
2930
from .start_node import *
@@ -35,7 +36,7 @@
3536
from .variable_splitting_node import BaseVariableSplittingNode
3637
from .video_understand_step_node import BaseVideoUnderstandNode
3738

38-
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseQuestionNode,
39+
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseSearchDocumentNode, BaseQuestionNode,
3940
BaseConditionNode, BaseReplyNode,
4041
BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode,
4142
BaseDocumentExtractNode,

apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
7777
image,
7878
**kwargs) -> NodeResult:
7979
# 处理不正确的参数
80-
if image is None or not isinstance(image, list):
81-
image = []
8280
workspace_id = self.workflow_manage.get_body().get('workspace_id')
8381
image_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
8482
**model_params_setting)
@@ -91,7 +89,7 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
9189
message_list = self.generate_message_list(image_model, system, prompt,
9290
self.get_history_message(history_chat_record, dialogue_number), image)
9391
self.context['message_list'] = message_list
94-
self.context['image_list'] = image
92+
self.generate_context_image(image)
9593
self.context['dialogue_type'] = dialogue_type
9694
if stream:
9795
r = image_model.stream(message_list)
@@ -104,6 +102,12 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
104102
'history_message': history_message, 'question': question.content}, {},
105103
_write_context=write_context)
106104

105+
def generate_context_image(self, image):
    """Record the images handed to this node under ``context['image_list']``.

    A single string starting with ``http`` is stored as a one-element list
    of ``{'url': <string>}``. Any other non-empty value is stored unchanged.
    ``None`` and empty values leave the context untouched.
    """
    is_remote_url = isinstance(image, str) and image.startswith('http')
    if is_remote_url:
        self.context['image_list'] = [{'url': image}]
        return
    if image is None or len(image) == 0:
        # Nothing to record; keep whatever the context already holds.
        return
    self.context['image_list'] = image
110+
107111
def get_history_message_for_details(self, history_chat_record, dialogue_number):
108112
start_index = len(history_chat_record) - dialogue_number
109113
history_message = reduce(lambda x, y: [*x, *y], [
@@ -164,28 +168,32 @@ def generate_history_human_message(self, chat_record):
164168
def generate_prompt_question(self, prompt):
    """Render ``prompt`` through the workflow manager and wrap it in a ``HumanMessage``."""
    rendered_prompt = self.workflow_manage.generate_prompt(prompt)
    return HumanMessage(rendered_prompt)
166170

167-
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
168-
if image is not None and len(image) > 0:
169-
# 处理多张图片
170-
images = []
171+
def _process_images(self, image):
172+
"""
173+
处理图像数据,转换为模型可识别的格式
174+
"""
175+
images = []
176+
if isinstance(image, str) and image.startswith('http'):
177+
images.append({'type': 'image_url', 'image_url': {'url': image}})
178+
elif image is not None and len(image) > 0:
171179
for img in image:
172-
if isinstance(img, str) and img.startswith('http'):
173-
images.append({'type': 'image_url', 'image_url': {'url': img}})
174-
else:
175-
file_id = img['file_id']
176-
file = QuerySet(File).filter(id=file_id).first()
177-
image_bytes = file.get_bytes()
178-
base64_image = base64.b64encode(image_bytes).decode("utf-8")
179-
image_format = what(None, image_bytes)
180-
images.append(
181-
{'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
182-
messages = [HumanMessage(
183-
content=[
184-
{'type': 'text', 'text': self.workflow_manage.generate_prompt(prompt)},
185-
*images
186-
])]
180+
file_id = img['file_id']
181+
file = QuerySet(File).filter(id=file_id).first()
182+
image_bytes = file.get_bytes()
183+
base64_image = base64.b64encode(image_bytes).decode("utf-8")
184+
image_format = what(None, image_bytes)
185+
images.append(
186+
{'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
187+
return images
188+
189+
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
190+
prompt_text = self.workflow_manage.generate_prompt(prompt)
191+
images = self._process_images(image)
192+
193+
if images:
194+
messages = [HumanMessage(content=[{'type': 'text', 'text': prompt_text}, *images])]
187195
else:
188-
messages = [HumanMessage(self.workflow_manage.generate_prompt(prompt))]
196+
messages = [HumanMessage(prompt_text)]
189197

190198
if system is not None and len(system) > 0:
191199
return [
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .impl import *
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# coding=utf-8
2+
from typing import Type, List
3+
4+
from django.utils.translation import gettext_lazy as _
5+
from rest_framework import serializers
6+
7+
from application.flow.i_step_node import INode, NodeResult
8+
9+
10+
class SearchDocumentStepNodeSerializer(serializers.Serializer):
    """Validate the parameters configured on a ``search-document-node`` step.

    Every field is optional and falls back to the default declared here.
    """

    # Knowledge bases whose documents are searched when the scope type is 'custom'.
    knowledge_id_list = serializers.ListField(
        required=False, child=serializers.UUIDField(required=True),
        label=_("knowledge id list"), default=list
    )
    # 'auto': match tags against keywords taken from the question;
    # 'custom': match tags against the explicit condition list.
    search_mode = serializers.ChoiceField(
        required=False, choices=['auto', 'custom'], label=_("search mode"), default='auto'
    )
    # 'custom': use knowledge_id_list; 'referencing': use search_scope_reference.
    search_scope_type = serializers.ChoiceField(
        required=False, choices=['custom', 'referencing'], label=_("search scope type"),
        allow_null=True, default='custom'
    )
    # Whether the referenced variable holds document ids or knowledge ids.
    search_scope_source = serializers.ChoiceField(
        required=False, choices=['document', 'knowledge'],
        label=_("search scope variable type"), default='knowledge'
    )
    # Reference path of the variable that supplies the search scope.
    search_scope_reference = serializers.ListField(
        required=False, label=_("search scope variable"), default=list
    )
    # Reference path of the variable that supplies the question text.
    question_reference = serializers.ListField(
        required=False, label=_("question reference address"), default=list
    )
    # How the entries of search_condition_list are combined.
    search_condition_type = serializers.ChoiceField(
        required=False, choices=['AND', 'OR'], label=_("search condition type"), default='AND'
    )
    # Tag conditions; each entry is read with the keys 'key', 'value' and 'compare'.
    search_condition_list = serializers.ListField(
        required=False, label=_("search condition list"), default=list
    )

    def is_valid(self, *, raise_exception=False):
        """Validate the input data, always raising on failure.

        ``raise_exception`` is accepted for signature compatibility only:
        validation errors are raised regardless of its value.

        Returns:
            ``True`` when the data is valid. The previous implementation
            dropped the parent's result and returned ``None``, which made
            ``if serializer.is_valid():`` always falsy.
        """
        return super().is_valid(raise_exception=True)
41+
42+
43+
class ISearchDocumentStepNode(INode):
    """Interface of the ``search-document-node`` workflow step."""

    type = 'search-document-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer class that validates this node's parameters."""
        return SearchDocumentStepNodeSerializer

    def _run(self):
        """Call ``execute`` with the validated node and flow parameters as keyword arguments."""
        node_params = self.node_params_serializer.data
        flow_params = self.flow_params_serializer.data
        return self.execute(**node_params, **flow_params)

    def execute(
            self,
            knowledge_id_list: List,
            search_mode: str,
            search_scope_type: str,
            search_scope_source: str,
            search_scope_reference: List,
            question_reference: List,
            search_condition_type: str,
            search_condition_list: List,
            **kwargs
    ) -> NodeResult:
        """Run the document search; this interface stub does nothing and returns ``None``.

        The named parameters mirror the fields of the node parameter
        serializer; any further keyword arguments are collected in ``kwargs``.
        """
        return None
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .base_search_document_node import BaseSearchDocumentNode
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# coding=utf-8
2+
from typing import List
3+
4+
import jieba
5+
from django.db.models import Q
6+
from django.db.models import QuerySet
7+
8+
from application.flow.i_step_node import NodeResult
9+
from application.flow.step_node.search_document_node.i_search_document_node import ISearchDocumentStepNode
10+
from common.constants.permission_constants import RoleConstants
11+
from common.database_model_manage.database_model_manage import DatabaseModelManage
12+
from knowledge.models import Document, DocumentTag, Knowledge
13+
14+
15+
class BaseSearchDocumentNode(ISearchDocumentStepNode):
    """Workflow node that selects documents by matching their tags.

    The candidate documents come either from the configured knowledge bases
    or from a referenced variable. They are then narrowed by tag matching:
    automatically from the keywords of a question, or through an explicit
    list of tag conditions.
    """

    def save_context(self, details, workflow_manage):
        """Restore this node's context from a previously recorded ``details`` mapping."""
        for key in ('document_list', 'knowledge_list', 'document_items',
                    'knowledge_items', 'question', 'run_time'):
            self.context[key] = details.get(key)

    def get_reference_content(self, fields: List[str]):
        """Resolve a reference path through the workflow manager.

        ``fields[0]`` and ``fields[1:]`` are passed to
        ``workflow_manage.get_reference_field``. An empty path resolves to
        ``None`` instead of raising ``IndexError``.
        """
        if not fields:
            return None
        return self.workflow_manage.get_reference_field(fields[0], fields[1:])

    def execute(self, knowledge_id_list: List, search_mode: str, search_scope_type: str, search_scope_source: str,
                search_scope_reference: List, question_reference: List, search_condition_type: str,
                search_condition_list: List,
                **kwargs) -> NodeResult:
        """Select the documents whose tags match and report them with their knowledge bases.

        Args:
            knowledge_id_list: knowledge bases to search when ``search_scope_type`` is ``'custom'``.
            search_mode: ``'auto'`` matches tags against question keywords;
                any other value uses ``search_condition_list``.
            search_scope_type: ``'custom'`` uses ``knowledge_id_list``; any
                other value uses the referenced variable.
            search_scope_source: ``'document'`` when the referenced variable
                holds document ids; otherwise it holds knowledge ids.
            search_scope_reference: reference path of the scope variable.
            question_reference: reference path of the question text.
            search_condition_type: ``'AND'`` or ``'OR'``.
            search_condition_list: tag conditions with ``key``/``value``/``compare``.

        Returns:
            A ``NodeResult`` carrying ``document_list`` and ``knowledge_list``
            (ids as strings) plus ``document_items`` and ``knowledge_items``
            (row dictionaries).
        """
        if search_scope_type == 'custom':  # knowledge bases picked by hand
            document_id_list = QuerySet(Document).filter(
                knowledge_id__in=knowledge_id_list
            ).values_list('id', flat=True)
        else:  # knowledge bases / documents referenced from an earlier step
            # An unresolved reference is treated as an empty scope.
            referenced_ids = self.get_reference_content(search_scope_reference) or []
            if search_scope_source == 'document':
                document_id_list = referenced_ids
            else:
                document_id_list = QuerySet(Document).filter(
                    knowledge_id__in=referenced_ids
                ).values_list('id', flat=True)

        # Permission filtering for chat users.
        get_knowledge_list_of_authorized = DatabaseModelManage.get_model('get_knowledge_list_of_authorized')
        chat_user_type = self.workflow_manage.get_body().get('chat_user_type')

        if get_knowledge_list_of_authorized is not None and RoleConstants.CHAT_USER.value.name == chat_user_type:
            # NOTE(review): the configured knowledge_id_list is passed even when
            # the scope comes from a reference; confirm that is intended, since
            # the list defaults to empty in that mode.
            authorized_knowledge_ids = get_knowledge_list_of_authorized(
                self.workflow_manage.get_body().get('chat_user_id'),
                knowledge_id_list
            )

            # Keep only documents that belong to an authorized knowledge base.
            document_id_list = QuerySet(Document).filter(
                id__in=document_id_list,
                knowledge_id__in=authorized_knowledge_ids
            ).values_list('id', flat=True)

        if search_mode == 'auto':  # derive the tag search from the question
            matched_ids = self.handle_auto_tags(document_id_list, question_reference)
        else:  # explicit tag conditions
            matched_ids = self.handle_custom_tags(
                document_id_list, search_condition_list, search_condition_type
            )

        # UUID -> str
        final_document_ids = [str(doc_id) for doc_id in matched_ids]
        document_items = QuerySet(Document).filter(id__in=final_document_ids).values()
        final_knowledge_ids = list({str(doc['knowledge_id']) for doc in document_items})
        knowledge_items = QuerySet(Knowledge).filter(id__in=final_knowledge_ids).values()

        return NodeResult({
            'document_list': final_document_ids,
            'document_items': list(document_items),
            'knowledge_list': final_knowledge_ids,
            'knowledge_items': list(knowledge_items)
        }, {})

    def handle_auto_tags(self, document_id_list: list, question_reference: list):
        """Return the ids of candidate documents having a tag value that contains a question keyword.

        The question is resolved from ``question_reference``, stored in
        ``context['question']`` (which ``get_details`` reports), and split
        into keywords with jieba. A missing or empty question yields an
        empty set.
        """
        question = self.get_reference_content(question_reference)
        self.context['question'] = question
        if not question:
            return set()

        # Whitespace-only tokens would match nearly every tag value, so drop them.
        keywords = [word for word in jieba.lcut(str(question)) if word.strip()]
        if not keywords:
            return set()

        # OR all keywords together so a single query finds every match.
        keyword_filter = Q()
        for keyword in keywords:
            keyword_filter |= Q(tag__value__icontains=keyword)

        return set(
            QuerySet(DocumentTag)
            .filter(document_id__in=document_id_list)
            .filter(keyword_filter)
            .values_list('document_id', flat=True)
            .distinct()
        )

    def handle_custom_tags(self, document_id_list: List, search_condition_list: list, search_condition_type: str):
        """Filter candidate documents by explicit tag conditions.

        Each condition is a mapping with ``key``, ``value`` (rendered through
        ``workflow_manage.generate_prompt``) and ``compare`` (``'contain'``,
        ``'eq'`` or ``'not_contain'``). Conditions with any other compare
        type are ignored. ``'not_contain'`` selects the candidates that have
        no tag with that key whose value contains the text, which includes
        candidates that carry no tags at all.

        Returns:
            A set of document ids as strings. With ``'AND'`` a document must
            satisfy every condition; otherwise one is enough. If no condition
            is usable, every candidate is returned.
        """
        # Candidates may arrive as UUIDs (query rows) or strings (referenced
        # variable); compare them as strings so set arithmetic is reliable.
        candidate_ids = {str(doc_id) for doc_id in document_id_list}
        if not search_condition_list:
            return candidate_ids

        parsed_conditions = []
        for condition in search_condition_list:
            parsed = self._build_tag_condition(condition)
            if parsed is not None:
                parsed_conditions.append(parsed)
        if not parsed_conditions:
            return candidate_ids

        if search_condition_type == 'AND':
            remaining_ids = candidate_ids
            for tag_filter, negate in parsed_conditions:
                tagged_ids = self._document_ids_with_tag(remaining_ids, tag_filter)
                remaining_ids = remaining_ids - tagged_ids if negate else tagged_ids
                if not remaining_ids:
                    break
            return remaining_ids

        # OR: positive conditions share one query, negated ones are complements.
        matched_ids = set()
        positive_filter = Q()
        has_positive = False
        for tag_filter, negate in parsed_conditions:
            if negate:
                matched_ids |= candidate_ids - self._document_ids_with_tag(candidate_ids, tag_filter)
            else:
                positive_filter |= tag_filter
                has_positive = True
        if has_positive:
            matched_ids |= self._document_ids_with_tag(candidate_ids, positive_filter)
        return matched_ids

    def _build_tag_condition(self, condition):
        """Translate one condition into ``(tag_filter, negate)``, or ``None`` for an unknown compare type."""
        tag_key = condition['key']
        field_value = self.workflow_manage.generate_prompt(condition['value'])
        compare_type = condition['compare']

        if compare_type == 'eq':
            return Q(tag__key=tag_key, tag__value=field_value), False
        if compare_type == 'contain':
            return Q(tag__key=tag_key, tag__value__icontains=field_value), False
        if compare_type == 'not_contain':
            return Q(tag__key=tag_key, tag__value__icontains=field_value), True
        return None

    @staticmethod
    def _document_ids_with_tag(document_ids, tag_filter):
        """Return, as strings, the ids among ``document_ids`` that have a tag matching ``tag_filter``."""
        rows = QuerySet(DocumentTag).filter(
            document_id__in=document_ids
        ).filter(tag_filter).values_list('document_id', flat=True).distinct()
        return {str(document_id) for document_id in rows}

    def get_details(self, index: int, **kwargs):
        """Return the execution details of this node, read from its context and status."""
        return {
            'name': self.node.properties.get('stepName'),
            'question': self.context.get('question'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'document_list': self.context.get('document_list'),
            'knowledge_list': self.context.get('knowledge_list'),
            'document_items': self.context.get('document_items'),
            'knowledge_items': self.context.get('knowledge_items'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message
        }

apps/application/flow/step_node/search_knowledge_node/i_search_knowledge_node.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@
1010
from typing import Type
1111

1212
from django.core import validators
13+
from django.utils.translation import gettext_lazy as _
1314
from rest_framework import serializers
1415

1516
from application.flow.i_step_node import INode, NodeResult
1617
from common.utils.common import flat_map
1718

18-
from django.utils.translation import gettext_lazy as _
19-
2019

2120
class DatasetSettingSerializer(serializers.Serializer):
2221
# 需要查询的条数
@@ -43,6 +42,17 @@ class SearchDatasetStepNodeSerializer(serializers.Serializer):
4342

4443
show_knowledge = serializers.BooleanField(required=True,
4544
label=_("The results are displayed in the knowledge sources"))
45+
search_scope_type = serializers.ChoiceField(
46+
required=False, choices=['custom', 'referencing'], label=_("search scope type"),
47+
allow_null=True, default='custom'
48+
)
49+
search_scope_source = serializers.ChoiceField(
50+
required=False, choices=['document', 'knowledge'],
51+
label=_("search scope variable type"), default='knowledge'
52+
)
53+
search_scope_reference = serializers.ListField(
54+
required=False, label=_("search scope variable"), default=list
55+
)
4656

4757
def is_valid(self, *, raise_exception=False):
4858
super().is_valid(raise_exception=True)
@@ -76,7 +86,9 @@ def _run(self):
7686
return self.execute(**self.node_params_serializer.data, question=str(question),
7787
exclude_paragraph_id_list=exclude_paragraph_id_list)
7888

79-
def execute(self, dataset_id_list, dataset_setting, question, show_knowledge,
89+
def execute(
        self,
        dataset_id_list,
        dataset_setting,
        question,
        show_knowledge,
        search_scope_type,
        search_scope_source,
        search_scope_reference,
        exclude_paragraph_id_list=None,
        **kwargs
) -> NodeResult:
    """Run the knowledge search; this interface stub does nothing and returns ``None``.

    ``search_scope_type``, ``search_scope_source`` and
    ``search_scope_reference`` correspond to the serializer fields of the
    same names; any further keyword arguments are collected in ``kwargs``.
    """
    return None

0 commit comments

Comments
 (0)